@@ -42,7 +42,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
4242 var rangeNode ast.Node
4343
4444 // Check runs for test functions only
45- if ! isTestFunction (funcDecl ) {
45+ isTest , testVar := isTestFunction (funcDecl )
46+ if ! isTest {
4647 return
4748 }
4849
@@ -53,16 +54,19 @@ func run(pass *analysis.Pass) (interface{}, error) {
5354 ast .Inspect (v , func (n ast.Node ) bool {
5455 // Check if the test method is calling t.parallel
5556 if ! funcHasParallelMethod {
56- funcHasParallelMethod = methodParallelIsCalledInTestFunction (n )
57+ funcHasParallelMethod = methodParallelIsCalledInTestFunction (n , testVar )
5758 }
5859
5960 // Check if the t.Run within the test function is calling t.parallel
60- if methodRunIsCalledInTestFunction (n ) {
61+ if methodRunIsCalledInTestFunction (n , testVar ) {
62+ // n is a call to t.Run; find out the name of the subtest's *testing.T parameter.
63+ innerTestVar := getRunCallbackParameterName (n )
64+
6165 hasParallel := false
6266 numberOfTestRun ++
6367 ast .Inspect (v , func (p ast.Node ) bool {
6468 if ! hasParallel {
65- hasParallel = methodParallelIsCalledInTestFunction (p )
69+ hasParallel = methodParallelIsCalledInTestFunction (p , innerTestVar )
6670 }
6771 return true
6872 })
@@ -81,12 +85,15 @@ func run(pass *analysis.Pass) (interface{}, error) {
8185 // nolint: gocritic
8286 switch r := n .(type ) {
8387 case * ast.ExprStmt :
84- if methodRunIsCalledInRangeStatement (r .X ) {
88+ if methodRunIsCalledInRangeStatement (r .X , testVar ) {
89+ // r.X is a call to t.Run; find out the name of the subtest's *testing.T parameter.
90+ innerTestVar := getRunCallbackParameterName (r .X )
91+
8592 rangeStatementOverTestCasesExists = true
8693 testRunLoopIdentifier = methodRunFirstArgumentObjectName (r .X )
8794
8895 if ! rangeStatementHasParallelMethod {
89- rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun (r .X )
96+ rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun (r .X , innerTestVar )
9097 }
9198 }
9299 }
@@ -165,7 +172,7 @@ func getLeftAndRightIdentifier(s ast.Stmt) (string, string) {
165172 return leftIdentifier , rightIdentifier
166173}
167174
168- func methodParallelIsCalledInMethodRun (node ast.Node ) bool {
175+ func methodParallelIsCalledInMethodRun (node ast.Node , testVar string ) bool {
169176 var methodParallelCalled bool
170177 // nolint: gocritic
171178 switch callExp := node .(type ) {
@@ -174,7 +181,7 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool {
174181 if ! methodParallelCalled {
175182 ast .Inspect (arg , func (n ast.Node ) bool {
176183 if ! methodParallelCalled {
177- methodParallelCalled = methodParallelIsCalledInRunMethod (n )
184+ methodParallelCalled = methodParallelIsCalledInRunMethod (n , testVar )
178185 return true
179186 }
180187 return false
@@ -185,32 +192,61 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool {
185192 return methodParallelCalled
186193}
187194
188- func methodParallelIsCalledInRunMethod (node ast.Node ) bool {
189- return exprCallHasMethod (node , "Parallel" )
195+ func methodParallelIsCalledInRunMethod (node ast.Node , testVar string ) bool {
196+ return exprCallHasMethod (node , testVar , "Parallel" )
190197}
191198
192- func methodParallelIsCalledInTestFunction (node ast.Node ) bool {
193- return exprCallHasMethod (node , "Parallel" )
199+ func methodParallelIsCalledInTestFunction (node ast.Node , testVar string ) bool {
200+ return exprCallHasMethod (node , testVar , "Parallel" )
194201}
195202
196- func methodRunIsCalledInRangeStatement (node ast.Node ) bool {
197- return exprCallHasMethod (node , "Run" )
203+ func methodRunIsCalledInRangeStatement (node ast.Node , testVar string ) bool {
204+ return exprCallHasMethod (node , testVar , "Run" )
198205}
199206
200- func methodRunIsCalledInTestFunction (node ast.Node ) bool {
201- return exprCallHasMethod (node , "Run" )
207+ func methodRunIsCalledInTestFunction (node ast.Node , testVar string ) bool {
208+ return exprCallHasMethod (node , testVar , "Run" )
202209}
203- func exprCallHasMethod (node ast.Node , methodName string ) bool {
210+ func exprCallHasMethod (node ast.Node , receiverName , methodName string ) bool {
204211 // nolint: gocritic
205212 switch n := node .(type ) {
206213 case * ast.CallExpr :
207214 if fun , ok := n .Fun .(* ast.SelectorExpr ); ok {
208- return fun .Sel .Name == methodName
215+ if receiver , ok := fun .X .(* ast.Ident ); ok {
216+ return receiver .Name == receiverName && fun .Sel .Name == methodName
217+ }
209218 }
210219 }
211220 return false
212221}
213222
223+ // In an expression of the form t.Run(x, func(q *testing.T) {...}), return the
224+ // value "q". In _most_ code, the name is probably t, but we shouldn't just
225+ // assume.
226+ func getRunCallbackParameterName (node ast.Node ) string {
227+ if n , ok := node .(* ast.CallExpr ); ok {
228+ if len (n .Args ) < 2 {
229+ // We want argument #2, but this call doesn't have two
230+ // arguments. Maybe it's not really t.Run.
231+ return ""
232+ }
233+ funcArg := n .Args [1 ]
234+ if fun , ok := funcArg .(* ast.FuncLit ); ok {
235+ if len (fun .Type .Params .List ) < 1 {
236+ // Subtest function doesn't have any parameters.
237+ return ""
238+ }
239+ firstArg := fun .Type .Params .List [0 ]
240+ // We'll assume firstArg.Type is *testing.T.
241+ if len (firstArg .Names ) < 1 {
242+ return ""
243+ }
244+ return firstArg .Names [0 ].Name
245+ }
246+ }
247+ return ""
248+ }
249+
214250// Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
215251func methodRunFirstArgumentObjectName (node ast.Node ) string {
216252 // nolint: gocritic
@@ -227,30 +263,31 @@ func methodRunFirstArgumentObjectName(node ast.Node) string {
227263 return ""
228264}
229265
230- // Checks if the function has the param type *testing.T)
231- func isTestFunction (funcDecl * ast.FuncDecl ) bool {
266+ // Checks if the function has the param type *testing.T; if it does, then the
267+ // parameter name is returned, too.
268+ func isTestFunction (funcDecl * ast.FuncDecl ) (bool , string ) {
232269 testMethodPackageType := "testing"
233270 testMethodStruct := "T"
234271 testPrefix := "Test"
235272
236273 if ! strings .HasPrefix (funcDecl .Name .Name , testPrefix ) {
237- return false
274+ return false , ""
238275 }
239276
240277 if funcDecl .Type .Params != nil && len (funcDecl .Type .Params .List ) != 1 {
241- return false
278+ return false , ""
242279 }
243280
244281 param := funcDecl .Type .Params .List [0 ]
245282 if starExp , ok := param .Type .(* ast.StarExpr ); ok {
246283 if selectExpr , ok := starExp .X .(* ast.SelectorExpr ); ok {
247284 if selectExpr .Sel .Name == testMethodStruct {
248285 if s , ok := selectExpr .X .(* ast.Ident ); ok {
249- return s .Name == testMethodPackageType
286+ return s .Name == testMethodPackageType , param . Names [ 0 ]. Name
250287 }
251288 }
252289 }
253290 }
254291
255- return false
292+ return false , ""
256293}
0 commit comments