Skip to content

Commit 787040e

Browse files
committed
Clean up
1 parent a5757ef commit 787040e

File tree

4 files changed

+126
-259
lines changed

4 files changed

+126
-259
lines changed

.github/workflows/ci.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -706,10 +706,10 @@ jobs:
706706
key: >-
707707
${{ runner.os }}-${{ steps.python.outputs.python-version }}-${{
708708
needs.info.outputs.python_cache_key }}
709-
- name: Run split_tests.py
709+
- name: Run split.py
710710
run: |
711711
. venv/bin/activate
712-
python -m script.split_tests_pytest ${{ needs.info.outputs.test_group_count }}
712+
python -m script.split_tests ${{ needs.info.outputs.test_group_count }}
713713
- name: Upload pytest_buckets
714714
uses: actions/[email protected]
715715
with:
@@ -1178,7 +1178,7 @@ jobs:
11781178
./script/check_dirty
11791179
11801180
coverage-partial:
1181-
name: Upload test coverage to Codecov
1181+
name: Upload test coverage to Codecov (partial suite)
11821182
if: needs.info.outputs.skip_coverage != 'true'
11831183
runs-on: ubuntu-22.04
11841184
needs:
@@ -1215,7 +1215,7 @@ jobs:
12151215
attempt_delay: 30000
12161216

12171217
coverage-full:
1218-
name: Upload test coverage to Codecov
1218+
name: Upload test coverage to Codecov (full suite)
12191219
if: needs.info.outputs.skip_coverage != 'true'
12201220
runs-on: ubuntu-22.04
12211221
needs:

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,6 @@ tmp_cache
132132

133133
# python-language-server / Rope
134134
.ropeproject
135+
136+
# Generated by script/split_tests.py
137+
pytest_buckets.txt

script/split_tests.py

+119-65
Original file line numberDiff line numberDiff line change
@@ -7,56 +7,28 @@
77
from dataclasses import dataclass, field
88
from math import ceil
99
import os
10-
import re
10+
import subprocess
11+
import sys
1112

1213

13-
@dataclass
14-
class TestFile:
15-
"""Class to hold test information."""
14+
class Bucket:
15+
"""Class to hold bucket."""
1616

17-
path: str
18-
total_tests: int
17+
def __init__(
18+
self,
19+
):
20+
"""Initialize bucket."""
21+
self.tests = 0
22+
self._paths = []
1923

24+
def add(self, part: TestFolder | TestFile):
25+
"""Add tests to bucket."""
26+
self.tests += part.total_tests
27+
self._paths.append(part.path)
2028

21-
@dataclass
22-
class TestFolder:
23-
"""Class to hold test information."""
24-
25-
path: str
26-
children: list[TestFolder | TestFile] = field(default_factory=list)
27-
28-
@property
29-
def total_tests(self) -> int:
30-
"""Return total tests."""
31-
return sum([test.total_tests for test in self.children])
32-
33-
def __repr__(self):
34-
"""Return representation."""
35-
return f"TestFolder(path='{self.path}', total={self.total_tests}, children={len(self.children)})"
36-
37-
38-
def count_tests(test_folder: TestFolder) -> int:
39-
"""Count tests in folder."""
40-
max_tests_in_file = 0
41-
for entry in os.listdir(test_folder.path):
42-
if entry in ("__pycache__", "__init__.py", "conftest.py"):
43-
continue
44-
45-
entry_path = os.path.join(test_folder.path, entry)
46-
if os.path.isdir(entry_path):
47-
sub_folder = TestFolder(entry_path)
48-
test_folder.children.append(sub_folder)
49-
max_tests_in_file = max(max_tests_in_file, count_tests(sub_folder))
50-
elif os.path.isfile(entry_path) and entry.startswith("test_"):
51-
tests = 0
52-
with open(entry_path) as file:
53-
for line in file:
54-
if re.match(r"^(async\s+)?def\s+test_\w+\(", line):
55-
tests += 1
56-
test_folder.children.append(TestFile(entry_path, tests))
57-
max_tests_in_file = max(max_tests_in_file, tests)
58-
59-
return max_tests_in_file
29+
def get_paths_line(self) -> str:
30+
"""Return paths."""
31+
return " ".join(self._paths) + "\n"
6032

6133

6234
class BucketHolder:
@@ -66,38 +38,119 @@ def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
6638
"""Initialize bucket holder."""
6739
self._tests_per_bucket = tests_per_bucket
6840
self._bucket_count = bucket_count
69-
self._current_bucket = []
70-
self._current_tests = 0
71-
self._buckets: list[list[str]] = [self._current_bucket]
41+
self._current_bucket = Bucket()
42+
self._buckets: list[Bucket] = [self._current_bucket]
43+
self._last_bucket = False
7244

7345
def split_tests(self, tests: TestFolder | TestFile) -> None:
    """Distribute ``tests`` over the buckets, recursing into folders.

    A folder or file is added whole to the current bucket when it keeps
    the bucket below ``tests_per_bucket`` (or when the last bucket has
    been reached, which takes everything that is left). A folder that
    does not fit is split into its children. A file that does not fit
    opens the next bucket.
    """
    fits = self._current_bucket.tests + tests.total_tests < self._tests_per_bucket
    if fits or self._last_bucket:
        self._current_bucket.add(tests)
        return

    if isinstance(tests, TestFolder):
        for test in tests.children.values():
            self.split_tests(test)
        return

    if self._current_bucket.tests == 0:
        # A single file that does not fit into an *empty* bucket will not
        # fit into any other empty bucket either. Opening a new bucket
        # would only leave empty buckets (blank lines in the output)
        # behind, so the file goes into the current one.
        self._current_bucket.add(tests)
        return

    # Create new bucket
    if len(self._buckets) == self._bucket_count:
        # Last bucket, add all tests to it
        self._last_bucket = True
    else:
        self._current_bucket = Bucket()
        self._buckets.append(self._current_bucket)

    # Add test to new bucket
    self.split_tests(tests)
9568

96-
def create_ouput_files(self) -> None:
97-
"""Create output files."""
69+
def create_ouput_file(self) -> None:
    """Write the buckets to ``pytest_buckets.txt`` and report their sizes.

    Each bucket becomes one line of the file (as produced by
    ``Bucket.get_paths_line``); the number of tests per bucket is printed
    to stdout.
    """
    # Explicit encoding keeps the file content independent of the locale
    # of the machine running the script.
    with open("pytest_buckets.txt", "w", encoding="utf-8") as file:
        for bucket in self._buckets:
            print(f"Bucket has {bucket.tests} tests")
            file.write(bucket.get_paths_line())
75+
76+
77+
@dataclass
class TestFile:
    """A single test file together with its number of tests.

    Instances can be compared with ``>`` on ``total_tests``, which allows
    ``max()`` to select the file holding the most tests.
    """

    # Path of the test file.
    path: str
    # Number of tests in the file.
    total_tests: int

    def __gt__(self, other: object) -> bool:
        """Return True if this file holds more tests than ``other``.

        Returns ``NotImplemented`` for operands that are not ``TestFile``
        instances so Python can fall back to the reflected comparison or
        raise ``TypeError``.
        """
        if not isinstance(other, TestFile):
            return NotImplemented
        return self.total_tests > other.total_tests
87+
88+
89+
@dataclass
class TestFolder:
    """A folder of tests whose children are stored by name."""

    # Path of the folder.
    path: str
    # Sub-folders and test files contained in this folder, keyed by name.
    children: dict[str, TestFolder | TestFile] = field(default_factory=dict)

    @property
    def total_tests(self) -> int:
        """Return the sum of ``total_tests`` over all children."""
        return sum(child.total_tests for child in self.children.values())

    def __repr__(self) -> str:
        """Return a short representation with test and child counts."""
        return f"TestFolder(total={self.total_tests}, children={len(self.children)})"
104+
105+
106+
def insert_at_correct_position(
    test_holder: TestFolder, test_path: str, total_tests: int
) -> None:
    """Insert a test file into the folder tree rooted at ``test_holder``.

    The first component of ``test_path`` is skipped. Every following
    component that ends with ``.py`` is stored as a ``TestFile`` (holding
    the full ``test_path`` and ``total_tests``); every other component is
    treated as a folder, which is created if it does not exist yet.
    """
    current_path = test_holder
    for part in test_path.split("/")[1:]:
        if part.endswith(".py"):
            current_path.children[part] = TestFile(test_path, total_tests)
        else:
            child = current_path.children.get(part)
            if child is None:
                # Only build a folder node when it is actually missing,
                # instead of constructing one for every lookup.
                child = TestFolder(os.path.join(current_path.path, part))
                current_path.children[part] = child
            current_path = child
118+
119+
120+
def collect_tests(path: str) -> tuple[TestFolder, TestFile]:
    """Collect all tests below ``path`` by running pytest in collect-only mode.

    Every non-empty line of the pytest output is expected to have the form
    ``<file path>: <number of tests>``. The process exits with status 1
    when pytest fails or when a line does not have that form.

    Returns a tuple of the folder tree holding all collected files and the
    ``TestFile`` with the highest number of tests.
    """
    result = subprocess.run(
        ["pytest", "--collect-only", "-qq", "-p", "no:warnings", path],
        check=False,
        capture_output=True,
        text=True,
    )

    if result.returncode != 0:
        print("Failed to collect tests:")
        print(result.stderr)
        print(result.stdout)
        sys.exit(1)

    folder = TestFolder(path.split("/")[0])
    # Register the requested path itself (with zero tests) in the tree.
    insert_at_correct_position(folder, path, 0)
    max_tests_in_file = TestFile("", 0)

    for line in result.stdout.splitlines():
        if not line.strip():
            continue
        parts = [x.strip() for x in line.split(":")]
        if len(parts) != 2:
            print(f"Unexpected line: {line}")
            sys.exit(1)

        # Use a dedicated name so the ``path`` argument is not overwritten.
        file_path = parts[0]
        try:
            total_tests = int(parts[1])
        except ValueError:
            # A non-numeric count is reported like any other malformed line
            # instead of ending in a traceback.
            print(f"Unexpected line: {line}")
            sys.exit(1)

        max_tests_in_file = max(max_tests_in_file, TestFile(file_path, total_tests))

        insert_at_correct_position(folder, file_path, total_tests)

    return (folder, max_tests_in_file)
101154

102155

103156
def main() -> None:
@@ -120,22 +173,23 @@ def check_greater_0(value: str) -> int:
120173

121174
arguments = parser.parse_args()
122175

123-
tests = TestFolder("tests")
124-
max_tests_in_file = count_tests(tests)
125-
print(f"Maximum tests in a single file: {max_tests_in_file}")
176+
(tests, max_tests_in_file) = collect_tests("tests")
177+
print(
178+
f"Maximum tests in a single file are {max_tests_in_file.total_tests} tests (in {max_tests_in_file.path})"
179+
)
126180
print(f"Total tests: {tests.total_tests}")
127181

128182
tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)
129183
print(f"Estimated tests per bucket: {tests_per_bucket}")
130184

131-
if max_tests_in_file > tests_per_bucket:
185+
if max_tests_in_file.total_tests > tests_per_bucket:
132186
raise ValueError(
133187
f"There are more tests in a single file ({max_tests_in_file}) than tests per bucket ({tests_per_bucket})"
134188
)
135189

136190
bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
137191
bucket_holder.split_tests(tests)
138-
bucket_holder.create_ouput_files()
192+
bucket_holder.create_ouput_file()
139193

140194

141195
if __name__ == "__main__":

0 commit comments

Comments
 (0)