7
7
from dataclasses import dataclass , field
8
8
from math import ceil
9
9
import os
10
- import re
10
+ import subprocess
11
+ import sys
11
12
12
13
13
- @dataclass
14
- class TestFile :
15
- """Class to hold test information."""
14
class Bucket:
    """Accumulate test paths together with the number of tests they contain."""

    def __init__(self) -> None:
        """Initialize an empty bucket."""
        # Running total of tests over every part added so far.
        self.tests = 0
        # Paths in insertion order; written out as one line by get_paths_line().
        self._paths: list[str] = []

    def add(self, part: TestFolder | TestFile) -> None:
        """Add a test file or folder to the bucket.

        Only ``part.total_tests`` and ``part.path`` are read.
        """
        self.tests += part.total_tests
        self._paths.append(part.path)

    def get_paths_line(self) -> str:
        """Return all paths as one space-separated line ending in a newline."""
        return " ".join(self._paths) + "\n"
60
32
61
33
62
34
class BucketHolder:
    """Distribute tests over a fixed number of buckets."""

    def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
        """Initialize bucket holder."""
        self._tests_per_bucket = tests_per_bucket
        self._bucket_count = bucket_count
        self._current_bucket = Bucket()
        self._buckets: list[Bucket] = [self._current_bucket]
        # Set once the final bucket is reached; from then on everything that
        # is left goes into it, regardless of size.
        self._last_bucket = False

    def split_tests(self, tests: TestFolder | TestFile) -> None:
        """Split tests into buckets.

        A file or folder that still fits into the current bucket is added as
        a whole. A folder that does not fit is split by recursing into its
        children. A file that does not fit starts a new bucket, unless the
        bucket limit has been reached, in which case the last bucket takes
        everything that remains.
        """
        # "<=" so a bucket can be filled to exactly its capacity. With a
        # strict "<", a file holding exactly tests_per_bucket tests is
        # rejected by every empty bucket, burning through all buckets and
        # ending up in the last one.
        if (
            self._current_bucket.tests + tests.total_tests <= self._tests_per_bucket
        ) or self._last_bucket:
            self._current_bucket.add(tests)
            return

        if isinstance(tests, TestFolder):
            for test in tests.children.values():
                self.split_tests(test)
            return

        # Create new bucket
        if len(self._buckets) == self._bucket_count:
            # Last bucket, add all tests to it
            self._last_bucket = True
        else:
            self._current_bucket = Bucket()
            self._buckets.append(self._current_bucket)

        # Add test to new bucket
        self.split_tests(tests)

    def create_ouput_file(self) -> None:
        """Create output file.

        Writes one line per bucket to ``pytest_buckets.txt`` in the current
        working directory and prints the size of each bucket.
        """
        with open("pytest_buckets.txt", "w") as file:
            for bucket in self._buckets:
                print(f"Bucket has {bucket.tests} tests")
                file.write(bucket.get_paths_line())
75
+
76
+
77
@dataclass
class TestFile:
    """Hold the path of a test file and the number of tests it contains."""

    path: str
    total_tests: int

    def __gt__(self, other: object) -> bool:
        """Return True if this file holds more tests than ``other``.

        Only the test counts are compared; the paths are ignored. Operands
        that are not a ``TestFile`` yield ``NotImplemented`` so Python can
        fall back to its default comparison handling.
        """
        if not isinstance(other, TestFile):
            return NotImplemented
        return self.total_tests > other.total_tests
87
+
88
+
89
@dataclass
class TestFolder:
    """Hold a folder path and the test files/folders directly below it."""

    path: str
    # Direct children, looked up by key; every value exposes ``total_tests``.
    children: dict[str, TestFolder | TestFile] = field(default_factory=dict)

    @property
    def total_tests(self) -> int:
        """Return the number of tests in all children, nested ones included."""
        return sum(child.total_tests for child in self.children.values())

    def __repr__(self) -> str:
        """Return a short representation with totals instead of all children."""
        return f"TestFolder(total={self.total_tests}, children={len(self.children)})"
104
+
105
+
106
def insert_at_correct_position(
    test_holder: TestFolder, test_path: str, total_tests: int
) -> None:
    """Insert a test path into the folder tree rooted at ``test_holder``.

    The first component of ``test_path`` is skipped because it is taken to be
    the root folder itself. Every following component that does not end in
    ``.py`` is descended into, creating the folder when it is missing. A
    component ending in ``.py`` is stored as a ``TestFile`` carrying the full
    ``test_path`` and ``total_tests``, replacing any entry of the same name.
    """
    current_folder = test_holder
    for part in test_path.split("/")[1:]:
        if part.endswith(".py"):
            current_folder.children[part] = TestFile(test_path, total_tests)
            continue

        # Only build a new folder when none exists yet, rather than creating
        # a throwaway instance for every component.
        child = current_folder.children.get(part)
        if child is None:
            child = TestFolder(os.path.join(current_folder.path, part))
            current_folder.children[part] = child
        current_folder = child
118
+
119
+
120
def collect_tests(path: str) -> tuple[TestFolder, TestFile]:
    """Collect all tests below ``path`` by running ``pytest --collect-only``.

    Every non-blank line of pytest's output is expected to have the form
    ``<file path>: <number of tests>``. The process exits with status 1 when
    pytest fails or when a line does not match that form.

    Returns a tuple of the folder tree holding all collected files and the
    file with the most tests (``TestFile("", 0)`` if nothing was collected).
    """
    result = subprocess.run(
        ["pytest", "--collect-only", "-qq", "-p", "no:warnings", path],
        check=False,
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        print("Failed to collect tests:")
        print(result.stderr)
        print(result.stdout)
        sys.exit(1)

    folder = TestFolder(path.split("/")[0])
    # Register the requested path itself so its intermediate folders exist
    # even before any file is inserted.
    insert_at_correct_position(folder, path, 0)
    max_tests_in_file = TestFile("", 0)

    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        parts = [x.strip() for x in line.split(":")]
        if len(parts) != 2:
            print(f"Unexpected line: {line}")
            sys.exit(1)

        # A separate name keeps the ``path`` argument intact.
        file_path = parts[0]
        try:
            total_tests = int(parts[1])
        except ValueError:
            # A non-numeric count is as unexpected as a malformed line.
            print(f"Unexpected line: {line}")
            sys.exit(1)

        max_tests_in_file = max(max_tests_in_file, TestFile(file_path, total_tests))
        insert_at_correct_position(folder, file_path, total_tests)

    return (folder, max_tests_in_file)
101
154
102
155
103
156
def main () -> None :
@@ -120,22 +173,23 @@ def check_greater_0(value: str) -> int:
120
173
121
174
arguments = parser .parse_args ()
122
175
123
- tests = TestFolder ("tests" )
124
- max_tests_in_file = count_tests (tests )
125
- print (f"Maximum tests in a single file: { max_tests_in_file } " )
176
+ (tests , max_tests_in_file ) = collect_tests ("tests" )
177
+ print (
178
+ f"Maximum tests in a single file are { max_tests_in_file .total_tests } tests (in { max_tests_in_file .path } )"
179
+ )
126
180
print (f"Total tests: { tests .total_tests } " )
127
181
128
182
tests_per_bucket = ceil (tests .total_tests / arguments .bucket_count )
129
183
print (f"Estimated tests per bucket: { tests_per_bucket } " )
130
184
131
- if max_tests_in_file > tests_per_bucket :
185
+ if max_tests_in_file . total_tests > tests_per_bucket :
132
186
raise ValueError (
133
187
f"There are more tests in a single file ({ max_tests_in_file } ) than tests per bucket ({ tests_per_bucket } )"
134
188
)
135
189
136
190
bucket_holder = BucketHolder (tests_per_bucket , arguments .bucket_count )
137
191
bucket_holder .split_tests (tests )
138
- bucket_holder .create_ouput_files ()
192
+ bucket_holder .create_ouput_file ()
139
193
140
194
141
195
if __name__ == "__main__" :
0 commit comments