Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions mockgen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package main

import (
"errors"
"flag"
"fmt"
"go/ast"
"go/build"
Expand Down Expand Up @@ -92,6 +93,20 @@ func sourceMode(source string) (*model.Package, error) {
for pkgPath := range dotImports {
pkg.DotImports = append(pkg.DotImports, pkgPath)
}

// Get positional arguments after the flags
ifaces := flag.Args()

// If there are interfaces provided as positional arguments, filter them
if len(ifaces) > 0 {
if pkg.Interfaces, err = filterInterfaces(pkg.Interfaces, ifaces); err != nil {
log.Fatalf("Filtering interfaces failed: %v", err)
}
} else {
// No interfaces provided, process all interfaces for backward compatibility
log.Printf("No interfaces specified, processing all interfaces")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, this is still doing the thing @JacobOaks's comment advised against:

Instead of parsing and then dropping interfaces that aren't specified, can we simply not parse ones that aren't requested? This is similar to how the exclusion flag already works and would avoid some wasted computation.


return pkg, nil
}

Expand Down Expand Up @@ -802,4 +817,36 @@ func packageNameOfDir(srcDir string) (string, error) {
return packageImport, nil
}

func filterInterfaces(all []*model.Interface, requested []string) ([]*model.Interface, error) {
// If no interfaces are requested, return all interfaces
if len(requested) == 0 {
return all, nil
}

requestedIfaces := make(map[string]struct{})
for _, iface := range requested {
requestedIfaces[iface] = struct{}{}
}

result := make([]*model.Interface, 0, len(requestedIfaces))
for _, iface := range all {
// Only add interfaces that are requested
if _, ok := requestedIfaces[iface.Name]; ok {
result = append(result, iface)
delete(requestedIfaces, iface.Name) // Remove matched iface from requested
}
}

// If any requested interfaces were not found, return an error
if len(requestedIfaces) > 0 {
var missing []string
for iface := range requestedIfaces {
missing = append(missing, iface)
}
return nil, fmt.Errorf("missing interfaces: %s", strings.Join(missing, ", "))
}

return result, nil
}

var errOutsideGoPath = errors.New("source directory is outside GOPATH")
127 changes: 127 additions & 0 deletions mockgen/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package main
import (
"go/parser"
"go/token"
"reflect"
"testing"

"go.uber.org/mock/mockgen/model"
)

func TestFileParser_ParseFile(t *testing.T) {
Expand Down Expand Up @@ -143,3 +146,127 @@ func TestParseArrayWithConstLength(t *testing.T) {
}
}
}

func Test_filterInterfaces(t *testing.T) {
type args struct {
all []*model.Interface
requested []string
}
tests := []struct {
name string
args args
want []*model.Interface
wantErr bool
}{
{
name: "no filter (returns all interfaces)",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{},
},
want: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
wantErr: false,
},
{
name: "filter by Foo",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo"},
},
want: []*model.Interface{
{
Name: "Foo",
},
},
wantErr: false,
},
{
name: "filter by Foo and Bar",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo", "Bar"},
},
want: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
wantErr: false,
},
{
name: "incorrect filter by Foo and Baz",
args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Foo", "Baz"},
},
want: nil,
wantErr: true,
},
{
name: "missing interface (Baz not found)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also add a test case for accidentally specified duplicate requested interfaces?

args: args{
all: []*model.Interface{
{
Name: "Foo",
},
{
Name: "Bar",
},
},
requested: []string{"Baz"},
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := filterInterfaces(tt.args.all, tt.args.requested)
if (err != nil) != tt.wantErr {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of wantErr being a bool, could we make this a string and compare the error contents?

t.Errorf("filterInterfaces() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("filterInterfaces() got = %v, want %v", got, tt.want)
}
})
}
}