Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"errors"
"flag"
"fmt"
"go/ast"
"go/build"
Expand All @@ -29,6 +30,7 @@ import (
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -61,6 +63,20 @@ func sourceMode(source string) (*model.Package, error) {
srcDir: srcDir,
}

// positional interface names -> include set
ifaces := flag.Args()
if len(ifaces) > 0 {
p.includeNamesSet = make(map[string]struct{}, len(ifaces))
for _, arg := range ifaces {
for _, name := range strings.Split(arg, ",") {
name = strings.TrimSpace(name)
if name != "" {
p.includeNamesSet[name] = struct{}{}
}
}
}
}

// Handle -imports.
dotImports := make(map[string]bool)
if *imports != "" {
Expand Down Expand Up @@ -92,6 +108,7 @@ func sourceMode(source string) (*model.Package, error) {
for pkgPath := range dotImports {
pkg.DotImports = append(pkg.DotImports, pkgPath)
}

return pkg, nil
}

Expand Down Expand Up @@ -168,6 +185,7 @@ type fileParser struct {
auxInterfaces *interfaceCache
srcDir string
excludeNamesSet map[string]struct{}
includeNamesSet map[string]struct{}
}

func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error {
Expand Down Expand Up @@ -228,17 +246,36 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag

var is []*model.Interface
for ni := range iterInterfaces(file) {
if _, ok := p.excludeNamesSet[ni.name.String()]; ok {
name := ni.name.String()

if _, ok := p.excludeNamesSet[name]; ok {
continue
}
i, err := p.parseInterface(ni.name.String(), importPath, ni)

if len(p.includeNamesSet) > 0 {
if _, ok := p.includeNamesSet[name]; !ok {
continue
}
}

i, err := p.parseInterface(name, importPath, ni)
if errors.Is(err, errConstraintInterface) {
continue
}
if err != nil {
return nil, err
}
is = append(is, i)

delete(p.includeNamesSet, name)
}
if len(p.includeNamesSet) > 0 {
missing := make([]string, 0, len(p.includeNamesSet))
for n := range p.includeNamesSet {
missing = append(missing, n)
}
sort.Strings(missing)
return nil, fmt.Errorf("requested interfaces not found: %s", strings.Join(missing, ", "))
}
return &model.Package{
Name: file.Name.String(),
Expand Down
82 changes: 82 additions & 0 deletions mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"go/parser"
"go/token"
"strings"
"testing"
)

Expand Down Expand Up @@ -143,3 +144,84 @@ func TestParseArrayWithConstLength(t *testing.T) {
}
}
}

func TestParseFile_IncludeOnlyRequested(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
// include только один интерфейс
includeNamesSet: map[string]struct{}{"InputMaker": {}},
}

pkg, err := p.parseFile("", file)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
t.Fatalf("Expected only InputMaker, got %v", pkg.Interfaces)
}
}

func TestParseFile_IncludeMissing_ReturnsError(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
includeNamesSet: map[string]struct{}{"DoesNotExist": {}},
}

_, err = p.parseFile("", file)
if err == nil || !strings.Contains(err.Error(), "requested interfaces not found") {
t.Fatalf("Expected missing interface error, got %v", err)
}
}

func TestParseFile_IncludeWithDuplicates_Dedupes(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Эмулируем «случайно указали дубликаты» как это делает sourceMode (через позиционные аргументы)
args := []string{"InputMaker", "InputMaker"} // дубликаты
include := make(map[string]struct{})
for _, a := range args {
for _, name := range strings.Split(a, ",") {
name = strings.TrimSpace(name)
if name != "" {
include[name] = struct{}{}
}
}
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
includeNamesSet: include,
}

pkg, err := p.parseFile("", file)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
t.Fatalf("Expected only InputMaker once, got %v", pkg.Interfaces)
}
}
Loading