diff --git a/pkg/paralleltest/paralleltest.go b/pkg/paralleltest/paralleltest.go index a6aeaaf4..02af29d3 100644 --- a/pkg/paralleltest/paralleltest.go +++ b/pkg/paralleltest/paralleltest.go @@ -89,6 +89,7 @@ func (a *parallelAnalyzer) analyzeTestRun(pass *analysis.Pass, n ast.Node, testV return true }) } else if ident, ok := callExpr.Args[1].(*ast.Ident); ok { + // Case 2: Direct function identifier: t.Run("name", myFunc) foundFunc := false for _, file := range pass.Files { for _, decl := range file.Decls { @@ -109,6 +110,9 @@ func (a *parallelAnalyzer) analyzeTestRun(pass *analysis.Pass, n ast.Node, testV if !foundFunc { analysis.hasParallel = false } + } else if builderCall, ok := callExpr.Args[1].(*ast.CallExpr); ok { + // Case 3: Function call that returns a function: t.Run("name", builder()) + analysis.hasParallel = a.checkBuilderFunctionForParallel(pass, builderCall) } } @@ -230,6 +234,86 @@ func (a *parallelAnalyzer) analyzeTestFunction(pass *analysis.Pass, funcDecl *as } } +// checkBuilderFunctionForParallel analyzes a function call that returns a test function +// to see if the returned function contains t.Parallel() +func (a *parallelAnalyzer) checkBuilderFunctionForParallel(pass *analysis.Pass, builderCall *ast.CallExpr) bool { + // Get the name of the builder function being called + var builderFuncName string + switch fun := builderCall.Fun.(type) { + case *ast.Ident: + builderFuncName = fun.Name + case *ast.SelectorExpr: + // Handle method calls like obj.Builder() + builderFuncName = fun.Sel.Name + default: + return false + } + + if builderFuncName == "" { + return false + } + + // Find the builder function declaration + for _, file := range pass.Files { + for _, decl := range file.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok || funcDecl.Name.Name != builderFuncName { + continue + } + + // Found the builder function, analyze it and return immediately + hasParallel := false + ast.Inspect(funcDecl, func(n ast.Node) bool { + if hasParallel { + return false + } + + // Look for return statements + returnStmt, ok := n.(*ast.ReturnStmt) + if !ok || len(returnStmt.Results) == 0 { + return true + } + + // Check if the return value is a function literal + for _, result := range returnStmt.Results { + if funcLit, ok := result.(*ast.FuncLit); ok { + // Get the parameter name from the returned function + var paramName string + if funcLit.Type != nil && funcLit.Type.Params != nil && len(funcLit.Type.Params.List) > 0 { + param := funcLit.Type.Params.List[0] + if len(param.Names) > 0 { + paramName = param.Names[0].Name + } + } + + // Inspect the returned function for t.Parallel() + if paramName != "" { + ast.Inspect(funcLit, func(p ast.Node) bool { + if methodParallelIsCalledInTestFunction(p, paramName) { + hasParallel = true + return false + } + return true + }) + + // Exit immediately if we found t.Parallel() + if hasParallel { + return false + } + } + } + } + return !hasParallel // Stop inspection if we found t.Parallel() + }) + + // Return immediately after processing the matching function + return hasParallel + } + } + + return false +} + func (a *parallelAnalyzer) run(pass *analysis.Pass) (interface{}, error) { inspector := inspector.New(pass.Files) diff --git a/pkg/paralleltest/testdata/src/t/t_test.go b/pkg/paralleltest/testdata/src/t/t_test.go index 0454659d..37809432 100644 --- a/pkg/paralleltest/testdata/src/t/t_test.go +++ b/pkg/paralleltest/testdata/src/t/t_test.go @@ -317,7 +317,7 @@ func TestRangeHelperWithDifferentParamNames(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - t.Run("sub1", rangeHelperWithCustomParam) // want "Function TestRangeHelperWithDifferentParamNames missing the call to method parallel in the test run" + t.Run("sub1", rangeHelperWithCustomParam) // want "Function TestRangeHelperWithDifferentParamNames missing the call to method parallel in the test run" t.Run("sub2", rangeHelperWithAnotherParam) // want "Function TestRangeHelperWithDifferentParamNames missing the call to method parallel in the test run" }) } @@ -330,3 +330,29 @@ func rangeHelperWithCustomParam(testT *testing.T) { func rangeHelperWithAnotherParam(t *testing.T) { fmt.Println("range another") } + +// Test cases with builder functions that return test functions +func TestBuilderFunctionReturningTestFunc(t *testing.T) { + t.Parallel() + t.Run("1", builderWithParallel()) + t.Run("2", builderWithParallel()) +} + +func builderWithParallel() func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + fmt.Println("test from builder") + } +} + +func TestBuilderFunctionMissingParallel(t *testing.T) { + t.Parallel() + t.Run("1", builderWithoutParallel()) // want "Function TestBuilderFunctionMissingParallel missing the call to method parallel in the test run" + t.Run("2", builderWithoutParallel()) // want "Function TestBuilderFunctionMissingParallel missing the call to method parallel in the test run" +} + +func builderWithoutParallel() func(t *testing.T) { + return func(t *testing.T) { + fmt.Println("test from builder without parallel") + } +}