diff --git a/mockgen/internal/tests/typed/bugreport.go b/mockgen/internal/tests/typed/bugreport.go index 49fc286..60a916a 100644 --- a/mockgen/internal/tests/typed/bugreport.go +++ b/mockgen/internal/tests/typed/bugreport.go @@ -1,6 +1,6 @@ package typed -//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Example +//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Source import ( "log" diff --git a/mockgen/internal/tests/typed/bugreport_mock.go b/mockgen/internal/tests/typed/bugreport_mock.go index 5c92812..038811e 100644 --- a/mockgen/internal/tests/typed/bugreport_mock.go +++ b/mockgen/internal/tests/typed/bugreport_mock.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Example +// mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Source // // Package typed is a generated GoMock package. diff --git a/mockgen/parse.go b/mockgen/parse.go index f43321c..da0cc5e 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -18,6 +18,7 @@ package main import ( "errors" + "flag" "fmt" "go/ast" "go/build" @@ -61,6 +62,18 @@ func sourceMode(source string) (*model.Package, error) { srcDir: srcDir, } + // positional interface names -> include set + if flag.NArg() > 1 { + return nil, errors.New("-source mode accepts at most one argument") + } + if flag.NArg() == 1 { + ifaces := strings.Split(flag.Arg(0), ",") + p.includeNamesSet = make(map[string]struct{}, len(ifaces)) + for _, name := range ifaces { + p.includeNamesSet[name] = struct{}{} + } + } + // Handle -imports. dotImports := make(map[string]bool) if *imports != "" { @@ -92,6 +105,7 @@ func sourceMode(source string) (*model.Package, error) { for pkgPath := range dotImports { pkg.DotImports = append(pkg.DotImports, pkgPath) } + return pkg, nil } @@ -168,6 +182,7 @@ type fileParser struct { auxInterfaces *interfaceCache srcDir string excludeNamesSet map[string]struct{} + includeNamesSet map[string]struct{} // empty to include all } func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error { @@ -228,10 +243,20 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag var is []*model.Interface for ni := range iterInterfaces(file) { - if _, ok := p.excludeNamesSet[ni.name.String()]; ok { + name := ni.name.String() + + if _, ok := p.excludeNamesSet[name]; ok { continue } - i, err := p.parseInterface(ni.name.String(), importPath, ni) + + // All interfaces are included if no filter was specified. + if len(p.includeNamesSet) > 0 { + if _, ok := p.includeNamesSet[name]; !ok { + continue + } + } + + i, err := p.parseInterface(name, importPath, ni) if errors.Is(err, errConstraintInterface) { continue } @@ -239,7 +264,10 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag return nil, err } is = append(is, i) + + delete(p.includeNamesSet, name) } + return &model.Package{ Name: file.Name.String(), PkgPath: importPath, diff --git a/mockgen/parse_test.go b/mockgen/parse_test.go index 3c4ba4c..d672927 100644 --- a/mockgen/parse_test.go +++ b/mockgen/parse_test.go @@ -3,6 +3,7 @@ package main import ( "go/parser" "go/token" + "strings" "testing" ) @@ -143,3 +144,88 @@ func TestParseArrayWithConstLength(t *testing.T) { } } } + +func TestParseFile_IncludeOnlyRequested(t *testing.T) { + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + p := fileParser{ + fileSet: fs, + imports: make(map[string]importedPackage), + importedInterfaces: newInterfaceCache(), + // include только один интерфейс + includeNamesSet: map[string]struct{}{"InputMaker": {}}, + } + + pkg, err := p.parseFile("", file) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" { + t.Fatalf("Expected only InputMaker, got %v", pkg.Interfaces) + } +} + +// When requested interface is missing, parser should ignore it (no error, no interfaces). +func TestParseFile_IncludeMissing_Ignored(t *testing.T) { + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + p := fileParser{ + fileSet: fs, + imports: make(map[string]importedPackage), + importedInterfaces: newInterfaceCache(), + includeNamesSet: map[string]struct{}{"DoesNotExist": {}}, + } + + pkg, err := p.parseFile("", file) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if len(pkg.Interfaces) != 0 { + t.Fatalf("Expected no interfaces, got %v", pkg.Interfaces) + } +} + +func TestParseFile_IncludeWithDuplicates_Dedupes(t *testing.T) { + fs := token.NewFileSet() + file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Эмулируем «случайно указали дубликаты» как это делает sourceMode (через позиционные аргументы) + args := []string{"InputMaker", "InputMaker"} // дубликаты + include := make(map[string]struct{}) + for _, a := range args { + for _, name := range strings.Split(a, ",") { + name = strings.TrimSpace(name) + if name != "" { + include[name] = struct{}{} + } + } + } + + p := fileParser{ + fileSet: fs, + imports: make(map[string]importedPackage), + importedInterfaces: newInterfaceCache(), + includeNamesSet: include, + } + + pkg, err := p.parseFile("", file) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" { + t.Fatalf("Expected only InputMaker once, got %v", pkg.Interfaces) + } +}