diff --git a/pkg/paralleltest/paralleltest.go b/pkg/paralleltest/paralleltest.go index 31f6f294..0e359f76 100644 --- a/pkg/paralleltest/paralleltest.go +++ b/pkg/paralleltest/paralleltest.go @@ -42,7 +42,8 @@ func run(pass *analysis.Pass) (interface{}, error) { var rangeNode ast.Node // Check runs for test functions only - if !isTestFunction(funcDecl) { + isTest, testVar := isTestFunction(funcDecl) + if !isTest { return } @@ -53,16 +54,19 @@ func run(pass *analysis.Pass) (interface{}, error) { ast.Inspect(v, func(n ast.Node) bool { // Check if the test method is calling t.parallel if !funcHasParallelMethod { - funcHasParallelMethod = methodParallelIsCalledInTestFunction(n) + funcHasParallelMethod = methodParallelIsCalledInTestFunction(n, testVar) } // Check if the t.Run within the test function is calling t.parallel - if methodRunIsCalledInTestFunction(n) { + if methodRunIsCalledInTestFunction(n, testVar) { + // n is a call to t.Run; find out the name of the subtest's *testing.T parameter. + innerTestVar := getRunCallbackParameterName(n) + hasParallel := false numberOfTestRun++ ast.Inspect(v, func(p ast.Node) bool { if !hasParallel { - hasParallel = methodParallelIsCalledInTestFunction(p) + hasParallel = methodParallelIsCalledInTestFunction(p, innerTestVar) } return true }) @@ -81,12 +85,15 @@ func run(pass *analysis.Pass) (interface{}, error) { // nolint: gocritic switch r := n.(type) { case *ast.ExprStmt: - if methodRunIsCalledInRangeStatement(r.X) { + if methodRunIsCalledInRangeStatement(r.X, testVar) { + // r.X is a call to t.Run; find out the name of the subtest's *testing.T parameter. + innerTestVar := getRunCallbackParameterName(r.X) + rangeStatementOverTestCasesExists = true testRunLoopIdentifier = methodRunFirstArgumentObjectName(r.X) if !rangeStatementHasParallelMethod { - rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X) + rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar) } } } @@ -165,7 +172,7 @@ func getLeftAndRightIdentifier(s ast.Stmt) (string, string) { return leftIdentifier, rightIdentifier } -func methodParallelIsCalledInMethodRun(node ast.Node) bool { +func methodParallelIsCalledInMethodRun(node ast.Node, testVar string) bool { var methodParallelCalled bool // nolint: gocritic switch callExp := node.(type) { @@ -174,7 +181,7 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool { if !methodParallelCalled { ast.Inspect(arg, func(n ast.Node) bool { if !methodParallelCalled { - methodParallelCalled = methodParallelIsCalledInRunMethod(n) + methodParallelCalled = methodParallelIsCalledInRunMethod(n, testVar) return true } return false @@ -185,32 +192,61 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool { return methodParallelCalled } -func methodParallelIsCalledInRunMethod(node ast.Node) bool { - return exprCallHasMethod(node, "Parallel") +func methodParallelIsCalledInRunMethod(node ast.Node, testVar string) bool { + return exprCallHasMethod(node, testVar, "Parallel") } -func methodParallelIsCalledInTestFunction(node ast.Node) bool { - return exprCallHasMethod(node, "Parallel") +func methodParallelIsCalledInTestFunction(node ast.Node, testVar string) bool { + return exprCallHasMethod(node, testVar, "Parallel") } -func methodRunIsCalledInRangeStatement(node ast.Node) bool { - return exprCallHasMethod(node, "Run") +func methodRunIsCalledInRangeStatement(node ast.Node, testVar string) bool { + return exprCallHasMethod(node, testVar, "Run") } -func methodRunIsCalledInTestFunction(node ast.Node) bool { - return exprCallHasMethod(node, "Run") +func methodRunIsCalledInTestFunction(node ast.Node, testVar string) bool { + return exprCallHasMethod(node, testVar, "Run") } -func exprCallHasMethod(node ast.Node, methodName string) bool { +func exprCallHasMethod(node ast.Node, receiverName, methodName string) bool { // nolint: gocritic switch n := node.(type) { case *ast.CallExpr: if fun, ok := n.Fun.(*ast.SelectorExpr); ok { - return fun.Sel.Name == methodName + if receiver, ok := fun.X.(*ast.Ident); ok { + return receiver.Name == receiverName && fun.Sel.Name == methodName + } } } return false } +// In an expression of the form t.Run(x, func(q *testing.T) {...}), return the +// value "q". In _most_ code, the name is probably t, but we shouldn't just +// assume. +func getRunCallbackParameterName(node ast.Node) string { + if n, ok := node.(*ast.CallExpr); ok { + if len(n.Args) < 2 { + // We want argument #2, but this call doesn't have two + // arguments. Maybe it's not really t.Run. + return "" + } + funcArg := n.Args[1] + if fun, ok := funcArg.(*ast.FuncLit); ok { + if len(fun.Type.Params.List) < 1 { + // Subtest function doesn't have any parameters. + return "" + } + firstArg := fun.Type.Params.List[0] + // We'll assume firstArg.Type is *testing.T. + if len(firstArg.Names) < 1 { + return "" + } + return firstArg.Names[0].Name + } + } + return "" +} + // Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T) func methodRunFirstArgumentObjectName(node ast.Node) string { // nolint: gocritic @@ -227,18 +263,19 @@ func methodRunFirstArgumentObjectName(node ast.Node) string { return "" } -// Checks if the function has the param type *testing.T) -func isTestFunction(funcDecl *ast.FuncDecl) bool { +// Checks if the function has the param type *testing.T; if it does, then the +// parameter name is returned, too. +func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) { testMethodPackageType := "testing" testMethodStruct := "T" testPrefix := "Test" if !strings.HasPrefix(funcDecl.Name.Name, testPrefix) { - return false + return false, "" } if funcDecl.Type.Params != nil && len(funcDecl.Type.Params.List) != 1 { - return false + return false, "" } param := funcDecl.Type.Params.List[0] @@ -246,11 +283,11 @@ func isTestFunction(funcDecl *ast.FuncDecl) bool { if selectExpr, ok := starExp.X.(*ast.SelectorExpr); ok { if selectExpr.Sel.Name == testMethodStruct { if s, ok := selectExpr.X.(*ast.Ident); ok { - return s.Name == testMethodPackageType + return s.Name == testMethodPackageType, param.Names[0].Name } } } } - return false + return false, "" } diff --git a/pkg/paralleltest/testdata/src/t/t_test.go b/pkg/paralleltest/testdata/src/t/t_test.go index 4b211b23..5cd10351 100644 --- a/pkg/paralleltest/testdata/src/t/t_test.go +++ b/pkg/paralleltest/testdata/src/t/t_test.go @@ -18,8 +18,8 @@ func TestFunctionSuccessfulRangeTest(t *testing.T) { }{{name: "foo"}} for _, tc := range testCases { tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + t.Run(tc.name, func(x *testing.T) { + x.Parallel() fmt.Println(tc.name) }) } @@ -129,11 +129,35 @@ func TestFunctionFirstOneTestRunMissingCallToParallel(t *testing.T) { func TestFunctionSecondOneTestRunMissingCallToParallel(t *testing.T) { t.Parallel() - t.Run("1", func(t *testing.T) { - t.Parallel() + t.Run("1", func(x *testing.T) { + x.Parallel() fmt.Println("1") }) t.Run("2", func(t *testing.T) { // want "Function TestFunctionSecondOneTestRunMissingCallToParallel has missing the call to method parallel in the test run" fmt.Println("2") }) } + +type notATest int + +func (notATest) Run(args ...interface{}) {} +func (notATest) Parallel() {} + +func TestFunctionWithRunLookalike(t *testing.T) { + t.Parallel() + + var other notATest + // These aren't t.Run, so they shouldn't be flagged as Run invocations missing calls to Parallel. + other.Run(1, 1) + other.Run(2, 2) +} + +func TestFunctionWithParallelLookalike(t *testing.T) { // want "Function TestFunctionWithParallelLookalike missing the call to method parallel" + var other notATest + // This isn't t.Parallel, so it doesn't qualify as a call to Parallel. + other.Parallel() +} + +func TestFunctionWithOtherTestingVar(q *testing.T) { + q.Parallel() +}