Skip to content

Commit

Permalink
Add disk usage check before downloading files (#19041)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas <[email protected]>
  • Loading branch information
tchaton and thomas authored Nov 21, 2023
1 parent 49caddd commit d3df127
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-tests-data.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ jobs:
- name: Testing Data
working-directory: tests/tests_data
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
timeout-minutes: 10
run: |
python -m coverage run --source lightning \
-m pytest -v --timeout=60 --durations=60
Expand Down
30 changes: 22 additions & 8 deletions src/lightning/data/streaming/data_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import shutil
import signal
import tempfile
import traceback
Expand All @@ -10,7 +11,6 @@
from datetime import datetime
from multiprocessing import Process, Queue
from queue import Empty
from shutil import copyfile, rmtree
from time import sleep, time
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse
Expand Down Expand Up @@ -101,6 +101,16 @@ def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: in
raise e


def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
usage = shutil.disk_usage(input_dir)

while (usage.free / 1000 / 1000 / 1000) <= threshold_in_gb:
sleep(sleep_time)
usage = shutil.disk_usage(input_dir)

return


def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
"""This function is used to download data from a remote directory to a cache directory to optimise reading."""
s3 = S3Client()
Expand All @@ -123,7 +133,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
continue

if input_dir.url is not None or input_dir.path is not None:
# 6. Download all the required paths to unblock the current index
if input_dir.url:
# 6. Wait for the removers to catch up when we are downloading data.
_wait_for_disk_usage_higher_than_threshold("/", 25)

# 7. Download all the required paths to unblock the current index
for path in paths:
local_path = path.replace(input_dir.path, cache_dir)

Expand All @@ -141,7 +155,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

elif os.path.isfile(path):
copyfile(path, local_path)
shutil.copyfile(path, local_path)
else:
raise ValueError(f"The provided {input_dir.url} isn't supported.")

Expand Down Expand Up @@ -198,7 +212,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
except Exception as e:
print(e)
elif os.path.isdir(output_dir.path):
copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
else:
raise ValueError(f"The provided {output_dir.path} isn't supported.")

Expand Down Expand Up @@ -686,7 +700,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
)
elif os.path.isdir(output_dir.path):
copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))

if num_nodes == 1 or node_rank is None:
return
Expand All @@ -707,7 +721,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
with open(node_index_filepath, "wb") as f:
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
elif os.path.isdir(output_dir.path):
copyfile(remote_filepath, node_index_filepath)
shutil.copyfile(remote_filepath, node_index_filepath)

merge_cache = Cache(cache_dir, chunk_bytes=1)
merge_cache._merge_no_wait()
Expand Down Expand Up @@ -948,15 +962,15 @@ def _cleanup_cache(self) -> None:

# Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_dir):
rmtree(cache_dir, ignore_errors=True)
shutil.rmtree(cache_dir, ignore_errors=True)

os.makedirs(cache_dir, exist_ok=True)

cache_data_dir = _get_cache_data_dir()

# Cleanup the cache data folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_data_dir):
rmtree(cache_data_dir, ignore_errors=True)
shutil.rmtree(cache_data_dir, ignore_errors=True)

os.makedirs(cache_data_dir, exist_ok=True)

Expand Down
9 changes: 9 additions & 0 deletions tests/tests_data/streaming/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_map_items_to_workers_weighted,
_remove_target,
_upload_fn,
_wait_for_disk_usage_higher_than_threshold,
_wait_for_file_to_exist,
)
from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
Expand Down Expand Up @@ -159,6 +160,7 @@ def fn(*_, **__):


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
def test_download_data_target(tmpdir):
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)
Expand Down Expand Up @@ -193,6 +195,13 @@ def fn(*_, **__):
assert os.listdir(cache_dir) == ["a.txt"]


def test_wait_for_disk_usage_higher_than_threshold():
disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
_wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
assert disk_usage_mock.call_count == 3


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_wait_for_file_to_exist():
import botocore
Expand Down

0 comments on commit d3df127

Please sign in to comment.