Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mockgen/internal/tests/typed/bugreport.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package typed

//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Example
//go:generate mockgen -typed -aux_files faux=faux/faux.go -destination bugreport_mock.go -package typed -source=bugreport.go Source

import (
"log"
Expand Down
2 changes: 1 addition & 1 deletion mockgen/internal/tests/typed/bugreport_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 30 additions & 2 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"errors"
"flag"
"fmt"
"go/ast"
"go/build"
Expand Down Expand Up @@ -61,6 +62,18 @@ func sourceMode(source string) (*model.Package, error) {
srcDir: srcDir,
}

// positional interface names -> include set
if flag.NArg() > 1 {
return nil, errors.New("-source mode accepts at most one argument")
}
if flag.NArg() == 1 {
ifaces := strings.Split(flag.Arg(0), ",")
p.includeNamesSet = make(map[string]struct{}, len(ifaces))
for _, name := range ifaces {
p.includeNamesSet[name] = struct{}{}
}
}

// Handle -imports.
dotImports := make(map[string]bool)
if *imports != "" {
Expand Down Expand Up @@ -92,6 +105,7 @@ func sourceMode(source string) (*model.Package, error) {
for pkgPath := range dotImports {
pkg.DotImports = append(pkg.DotImports, pkgPath)
}

return pkg, nil
}

Expand Down Expand Up @@ -168,6 +182,7 @@ type fileParser struct {
auxInterfaces *interfaceCache
srcDir string
excludeNamesSet map[string]struct{}
includeNamesSet map[string]struct{} // empty to include all
}

func (p *fileParser) errorf(pos token.Pos, format string, args ...any) error {
Expand Down Expand Up @@ -228,18 +243,31 @@ func (p *fileParser) parseFile(importPath string, file *ast.File) (*model.Packag

var is []*model.Interface
for ni := range iterInterfaces(file) {
if _, ok := p.excludeNamesSet[ni.name.String()]; ok {
name := ni.name.String()

if _, ok := p.excludeNamesSet[name]; ok {
continue
}
i, err := p.parseInterface(ni.name.String(), importPath, ni)

// All interfaces are included if no filter was specified.
if len(p.includeNamesSet) > 0 {
if _, ok := p.includeNamesSet[name]; !ok {
continue
}
}

i, err := p.parseInterface(name, importPath, ni)
if errors.Is(err, errConstraintInterface) {
continue
}
if err != nil {
return nil, err
}
is = append(is, i)

delete(p.includeNamesSet, name)
}

return &model.Package{
Name: file.Name.String(),
PkgPath: importPath,
Expand Down
86 changes: 86 additions & 0 deletions mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"go/parser"
"go/token"
"strings"
"testing"
)

Expand Down Expand Up @@ -143,3 +144,88 @@ func TestParseArrayWithConstLength(t *testing.T) {
}
}
}

func TestParseFile_IncludeOnlyRequested(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
// include только один интерфейс
includeNamesSet: map[string]struct{}{"InputMaker": {}},
}

pkg, err := p.parseFile("", file)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
t.Fatalf("Expected only InputMaker, got %v", pkg.Interfaces)
}
}

// When requested interface is missing, parser should ignore it (no error, no interfaces).
func TestParseFile_IncludeMissing_Ignored(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
includeNamesSet: map[string]struct{}{"DoesNotExist": {}},
}

pkg, err := p.parseFile("", file)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if len(pkg.Interfaces) != 0 {
t.Fatalf("Expected no interfaces, got %v", pkg.Interfaces)
}
}

func TestParseFile_IncludeWithDuplicates_Dedupes(t *testing.T) {
fs := token.NewFileSet()
file, err := parser.ParseFile(fs, "internal/tests/custom_package_name/greeter/greeter.go", nil, 0)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Эмулируем «случайно указали дубликаты» как это делает sourceMode (через позиционные аргументы)
args := []string{"InputMaker", "InputMaker"} // дубликаты
include := make(map[string]struct{})
for _, a := range args {
for _, name := range strings.Split(a, ",") {
name = strings.TrimSpace(name)
if name != "" {
include[name] = struct{}{}
}
}
}

p := fileParser{
fileSet: fs,
imports: make(map[string]importedPackage),
importedInterfaces: newInterfaceCache(),
includeNamesSet: include,
}

pkg, err := p.parseFile("", file)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if len(pkg.Interfaces) != 1 || pkg.Interfaces[0].Name != "InputMaker" {
t.Fatalf("Expected only InputMaker once, got %v", pkg.Interfaces)
}
}