diff --git a/README.md b/README.md index 950694e1..2a6f0001 100644 --- a/README.md +++ b/README.md @@ -51,10 +51,15 @@ that uses reflection to understand interfaces. It is enabled by passing two non-flag arguments: an import path, and a comma-separated list of symbols. +You can use "." to refer to the current path's package. + Example: ```bash mockgen database/sql/driver Conn,Driver + +# Convenient for `go:generate`. +mockgen . Conn,Driver ``` The `mockgen` command is used to generate source code for a mock diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 1f14aa7c..b13885d9 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -69,6 +69,7 @@ func main() { var pkg *model.Package var err error + var packageName string if *source != "" { pkg, err = sourceMode(*source) } else { @@ -76,7 +77,18 @@ func main() { usage() log.Fatal("Expected exactly two arguments") } - pkg, err = reflectMode(flag.Arg(0), strings.Split(flag.Arg(1), ",")) + packageName = flag.Arg(0) + if packageName == "." { + dir, err := os.Getwd() + if err != nil { + log.Fatalf("Get current directory failed: %v", err) + } + packageName, err = packageNameOfDir(dir) + if err != nil { + log.Fatalf("Parse package name failed: %v", err) + } + } + pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ",")) } if err != nil { log.Fatalf("Loading input failed: %v", err) @@ -130,7 +142,7 @@ func main() { if *source != "" { g.filename = *source } else { - g.srcPackage = flag.Arg(0) + g.srcPackage = packageName g.srcInterfaces = flag.Arg(1) } diff --git a/mockgen/parse.go b/mockgen/parse.go index 2fdda65e..d88f3c95 100644 --- a/mockgen/parse.go +++ b/mockgen/parse.go @@ -24,6 +24,7 @@ import ( "go/build" "go/parser" "go/token" + "io/ioutil" "log" "path" "path/filepath" @@ -48,19 +49,10 @@ func sourceMode(source string) (*model.Package, error) { return nil, fmt.Errorf("failed getting source directory: %v", err) } - cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir} - pkgs, err := packages.Load(cfg, "file="+source) + packageImport, err := parsePackageImport(source, srcDir) if err != nil { return nil, err } - if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 { - return nil, errors.New("loading package failed") - } - - packageImport := pkgs[0].PkgPath - - // It is illegal to import a _test package. - packageImport = strings.TrimSuffix(packageImport, "_test") fs := token.NewFileSet() file, err := parser.ParseFile(fs, source, nil, 0) @@ -519,3 +511,46 @@ func isVariadic(f *ast.FuncType) bool { _, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis) return ok } + +// packageNameOfDir get package import path via dir +func packageNameOfDir(srcDir string) (string, error) { + files, err := ioutil.ReadDir(srcDir) + if err != nil { + log.Fatal(err) + } + + var goFilePath string + for _, file := range files { + if !file.IsDir() && strings.HasSuffix(file.Name(), ".go") { + goFilePath = file.Name() + break + } + } + if goFilePath == "" { + return "", fmt.Errorf("go source file not found %s", srcDir) + } + + packageImport, err := parsePackageImport(goFilePath, srcDir) + if err != nil { + return "", err + } + return packageImport, nil +} + +// parseImportPackage get package import path via source file +func parsePackageImport(source, srcDir string) (string, error) { + cfg := &packages.Config{Mode: packages.LoadFiles, Tests: true, Dir: srcDir} + pkgs, err := packages.Load(cfg, "file="+source) + if err != nil { + return "", err + } + if packages.PrintErrors(pkgs) > 0 || len(pkgs) == 0 { + return "", errors.New("loading package failed") + } + + packageImport := pkgs[0].PkgPath + + // It is illegal to import a _test package. + packageImport = strings.TrimSuffix(packageImport, "_test") + return packageImport, nil +}