Skip to content

Commit 4879531

Browse files
authored
[test fetcher] Always include the directly related test files (#30050)
* fix * fix --------- Co-authored-by: ydshieh <[email protected]>
1 parent de11d0b commit 4879531

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

utils/tests_fetcher.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -958,10 +958,25 @@ def has_many_models(tests):
958958
model_tests = {Path(t).parts[2] for t in tests if t.startswith("tests/models/")}
959959
return len(model_tests) > num_model_tests // 2
960960

961-
def filter_tests(tests):
962-
return [t for t in tests if not t.startswith("tests/models/") or Path(t).parts[2] in IMPORTANT_MODELS]
961+
# for each module (if specified in the argument `module`) of the form `models/my_model` (i.e. starting with it),
962+
# we always keep the tests (those are already in the argument `tests`) which are in `tests/models/my_model`.
963+
# This is to avoid them being excluded when a module has many impacted tests: the directly related test files should
964+
# always be included!
965+
def filter_tests(tests, module=""):
966+
return [
967+
t
968+
for t in tests
969+
if not t.startswith("tests/models/")
970+
or Path(t).parts[2] in IMPORTANT_MODELS
971+
# at this point, `t` is of the form `tests/models/my_model`, and we check if `models/my_model`
972+
# (i.e. `parts[1:3]`) is in `module`.
973+
or "/".join(Path(t).parts[1:3]) in module
974+
]
963975

964-
return {module: (filter_tests(tests) if has_many_models(tests) else tests) for module, tests in test_map.items()}
976+
return {
977+
module: (filter_tests(tests, module=module) if has_many_models(tests) else tests)
978+
for module, tests in test_map.items()
979+
}
965980

966981

967982
def check_imports_all_exist():

0 commit comments

Comments
 (0)