diff --git a/python/cuml/tests/conftest.py b/python/cuml/tests/conftest.py index fca925e274..e42fe5920d 100644 --- a/python/cuml/tests/conftest.py +++ b/python/cuml/tests/conftest.py @@ -264,6 +264,8 @@ def pytest_configure(config): else: hypothesis.settings.load_profile("unit") + config.pluginmanager.register(DownloadDataPlugin(), "download_data") + def pytest_pyfunc_call(pyfuncitem): """Skip tests that require the cudf.pandas accelerator. @@ -403,6 +405,23 @@ def random_seed(request): # ============================================================================= +class DownloadDataPlugin: + """Download data before workers are spawned. + + This avoids downloading data in each worker, which can lead to races. + """ + + def pytest_configure(self, config): + if not hasattr(config, "workerinput"): + # We're in the controller process, not a worker. Let's fetch all + # the datasets we might use. + fetch_20newsgroups() + fetch_california_housing() + datasets.load_digits() + datasets.load_diabetes() + datasets.load_breast_cancer() + + def dataset_fetch_retry(func, attempts=3, min_wait=1, max_wait=10): """Decorator for retrying dataset fetching operations with exponential backoff.