diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml
index fccb4394fa999..baf37c27944a7 100644
--- a/.github/workflows/ci-tests-data.yml
+++ b/.github/workflows/ci-tests-data.yml
@@ -95,6 +95,7 @@ jobs:
     - name: Testing Data
       working-directory: tests/tests_data
       # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
+      timeout-minutes: 10
       run: |
         python -m coverage run --source lightning \
           -m pytest -v --timeout=60 --durations=60
diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index 4b213a2a3fcde..8431cc70845e8 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil
 import signal
 import tempfile
 import traceback
@@ -10,7 +11,6 @@
 from datetime import datetime
 from multiprocessing import Process, Queue
 from queue import Empty
-from shutil import copyfile, rmtree
 from time import sleep, time
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse
@@ -101,6 +101,16 @@ def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: in
             raise e


+def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
+    usage = shutil.disk_usage(input_dir)
+
+    while (usage.free / 1000 / 1000 / 1000) <= threshold_in_gb:
+        sleep(sleep_time)
+        usage = shutil.disk_usage(input_dir)
+
+    return
+
+
 def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
     """This function is used to download data from a remote directory to a cache directory to optimise reading."""
     s3 = S3Client()
@@ -123,7 +133,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                 continue

             if input_dir.url is not None or input_dir.path is not None:
-                # 6. Download all the required paths to unblock the current index
+                if input_dir.url:
+                    # 6. Wait for the removers to catch up when we are downloading data.
+                    _wait_for_disk_usage_higher_than_threshold("/", 25)
+
+                # 7. Download all the required paths to unblock the current index
                 for path in paths:
                     local_path = path.replace(input_dir.path, cache_dir)

@@ -141,7 +155,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

                     elif os.path.isfile(path):
-                        copyfile(path, local_path)
+                        shutil.copyfile(path, local_path)
                     else:
                         raise ValueError(f"The provided {input_dir.url} isn't supported.")

@@ -198,7 +212,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
             except Exception as e:
                 print(e)
         elif os.path.isdir(output_dir.path):
-            copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")

@@ -686,7 +700,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                 local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
             )
         elif os.path.isdir(output_dir.path):
-            copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))

         if num_nodes == 1 or node_rank is None:
             return
@@ -707,7 +721,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                 with open(node_index_filepath, "wb") as f:
                     s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
             elif os.path.isdir(output_dir.path):
-                copyfile(remote_filepath, node_index_filepath)
+                shutil.copyfile(remote_filepath, node_index_filepath)

         merge_cache = Cache(cache_dir, chunk_bytes=1)
         merge_cache._merge_no_wait()
@@ -948,7 +962,7 @@ def _cleanup_cache(self) -> None:

         # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_dir):
-            rmtree(cache_dir, ignore_errors=True)
+            shutil.rmtree(cache_dir, ignore_errors=True)

         os.makedirs(cache_dir, exist_ok=True)

@@ -956,7 +970,7 @@ def _cleanup_cache(self) -> None:

         # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_data_dir):
-            rmtree(cache_data_dir, ignore_errors=True)
+            shutil.rmtree(cache_data_dir, ignore_errors=True)

         os.makedirs(cache_data_dir, exist_ok=True)

diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index 45316e2dbb46f..610250cb26368 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -21,6 +21,7 @@
     _map_items_to_workers_weighted,
     _remove_target,
     _upload_fn,
+    _wait_for_disk_usage_higher_than_threshold,
     _wait_for_file_to_exist,
 )
 from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
@@ -159,6 +160,7 @@ def fn(*_, **__):


 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
+@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
 def test_download_data_target(tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
@@ -193,6 +195,13 @@ def fn(*_, **__):
     assert os.listdir(cache_dir) == ["a.txt"]


+def test_wait_for_disk_usage_higher_than_threshold():
+    disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
+    with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
+        _wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
+    assert disk_usage_mock.call_count == 3
+
+
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 def test_wait_for_file_to_exist():
     import botocore
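The new `_wait_for_disk_usage_higher_than_threshold` helper applies back-pressure to the download workers: when data is streamed from a remote `input_dir.url`, each downloader polls `shutil.disk_usage("/")` and blocks until the remover processes have freed more than the threshold, measured in decimal gigabytes (free bytes / 1e9). The snippet below is a minimal standalone sketch of that polling pattern only; the `wait_for_free_space` name and the `__main__` driver are illustrative and not part of this diff.

```python
import shutil
from time import sleep


def wait_for_free_space(path: str = "/", threshold_in_gb: float = 25, sleep_time: float = 3) -> None:
    """Block until the filesystem backing `path` has more than `threshold_in_gb` free.

    Mirrors the polling loop added to data_processor.py: free bytes are converted to
    decimal gigabytes (bytes / 1e9) and re-checked every `sleep_time` seconds.
    """
    while shutil.disk_usage(path).free / 1e9 <= threshold_in_gb:
        sleep(sleep_time)


if __name__ == "__main__":
    # With a tiny threshold this returns immediately on most machines; raise it
    # above the machine's free space to watch the loop block instead.
    wait_for_free_space("/", threshold_in_gb=0.001, sleep_time=0.1)
    print("enough free disk space; downloads can proceed")
```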