Add disk usage check before downloading files #19041

Merged: 12 commits, Nov 21, 2023
Changes from 8 commits
1 change: 1 addition & 0 deletions .github/workflows/ci-tests-data.yml
@@ -95,6 +95,7 @@ jobs:
- name: Testing Data
working-directory: tests/tests_data
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
timeout-minutes: 10
run: |
python -m coverage run --source lightning \
-m pytest -v --timeout=60 --durations=60
Expand Down
30 changes: 22 additions & 8 deletions src/lightning/data/streaming/data_processor.py
@@ -1,6 +1,7 @@
import json
import logging
import os
+import shutil
import signal
import tempfile
import traceback
@@ -10,7 +11,6 @@
from datetime import datetime
from multiprocessing import Process, Queue
from queue import Empty
-from shutil import copyfile, rmtree
from time import sleep, time
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from urllib import parse
@@ -101,6 +101,16 @@ def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: in
raise e


def _wait_for_disk_usage_higher_than_threshold(threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
    usage = shutil.disk_usage("/")

    while (usage.free / 1000 / 1000 / 1000) <= threshold_in_gb:
        sleep(sleep_time)
        usage = shutil.disk_usage("/")

    return

Review thread on this function:

Contributor: We should probably check the disk usage of the path we are writing to, e.g. instead of "/" it should probably be the cache dir.

Contributor Author (@tchaton, Nov 21, 2023): I think it is safer to check the overall machine state to avoid any failures, as this would be a costly mistake.

Contributor: My argument is that it would be a costly mistake not to measure the disk usage where we are saving. For example, if your /home is not on the same disk as /. The function could simply take a path as input, like _wait_for_disk_usage_higher_than_threshold(input_dir.path).

Contributor Author: The root should include everything on the machine, so it would be independent of where home is.
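A minimal sketch of the path-aware variant the reviewer suggested (not part of the merged PR; the signature and defaults below are assumptions for illustration):

    import shutil
    from time import sleep


    def _wait_for_disk_usage_higher_than_threshold(path: str = "/", threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
        # Check free space on the filesystem that actually holds `path` (e.g. the cache dir),
        # instead of always checking "/".
        while shutil.disk_usage(path).free / 1000 / 1000 / 1000 <= threshold_in_gb:
            sleep(sleep_time)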


def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
"""This function is used to download data from a remote directory to a cache directory to optimise reading."""
s3 = S3Client()
@@ -123,7 +133,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
continue

if input_dir.url is not None or input_dir.path is not None:
-# 6. Download all the required paths to unblock the current index
+if input_dir.url:
+    # 6. Wait for the removers to catch up when we are downloading data.
+    _wait_for_disk_usage_higher_than_threshold(25)
+
+# 7. Download all the required paths to unblock the current index
for path in paths:
local_path = path.replace(input_dir.path, cache_dir)
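For reference (not part of the diff): shutil.disk_usage reports sizes in bytes, so the 25 passed to _wait_for_disk_usage_higher_than_threshold above is a threshold in decimal gigabytes (three divisions by 1000), not gibibytes. A quick standalone check:

    import shutil

    usage = shutil.disk_usage("/")
    free_gb = usage.free / 1000 / 1000 / 1000  # the quantity the new check compares to the threshold
    free_gib = usage.free / (1024 ** 3)        # binary units give a slightly smaller number for the same space
    print(f"free: {free_gb:.1f} GB ({free_gib:.1f} GiB)")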

@@ -141,7 +155,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

elif os.path.isfile(path):
-copyfile(path, local_path)
+shutil.copyfile(path, local_path)
else:
raise ValueError(f"The provided {input_dir.url} isn't supported.")

@@ -198,7 +212,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
except Exception as e:
print(e)
elif os.path.isdir(output_dir.path):
-copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
else:
raise ValueError(f"The provided {output_dir.path} isn't supported.")

@@ -686,7 +700,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
)
elif os.path.isdir(output_dir.path):
-copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))

if num_nodes == 1 or node_rank is None:
return
@@ -707,7 +721,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
with open(node_index_filepath, "wb") as f:
s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
elif os.path.isdir(output_dir.path):
-copyfile(remote_filepath, node_index_filepath)
+shutil.copyfile(remote_filepath, node_index_filepath)

merge_cache = Cache(cache_dir, chunk_bytes=1)
merge_cache._merge_no_wait()
@@ -948,15 +962,15 @@ def _cleanup_cache(self) -> None:

# Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_dir):
-rmtree(cache_dir, ignore_errors=True)
+shutil.rmtree(cache_dir, ignore_errors=True)

os.makedirs(cache_dir, exist_ok=True)

cache_data_dir = _get_cache_data_dir()

# Cleanup the cache data folder to avoid corrupted files from previous run to be there.
if os.path.exists(cache_data_dir):
-rmtree(cache_data_dir, ignore_errors=True)
+shutil.rmtree(cache_data_dir, ignore_errors=True)

os.makedirs(cache_data_dir, exist_ok=True)

38 changes: 37 additions & 1 deletion tests/tests_data/streaming/test_data_processor.py
@@ -21,7 +21,9 @@
_map_items_to_workers_weighted,
_remove_target,
_upload_fn,
_wait_for_disk_usage_higher_than_threshold,
_wait_for_file_to_exist,
shutil,
)
from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
from lightning_utilities.core.imports import RequirementCache
@@ -159,7 +161,7 @@ def fn(*_, **__):


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-def test_download_data_target(tmpdir):
+def test_download_data_target(monkeypatch, tmpdir):
input_dir = os.path.join(tmpdir, "input_dir")
os.makedirs(input_dir, exist_ok=True)

@@ -176,6 +178,20 @@ def test_download_data_target(tmpdir):

paths = [os.path.join(input_dir, "a.txt"), None]

class Usage:
def __init__(self, free):
self.free = free * 1000 * 1000 * 1000

usages = [Usage(100)]

def fn(*_, **__):
value = usages.pop(0)
if value is None:
return value
return value

monkeypatch.setattr(shutil, "disk_usage", fn)

def fn(*_, **__):
value = paths.pop(0)
if value is None:
@@ -193,6 +209,26 @@ def fn(*_, **__):
assert os.listdir(cache_dir) == ["a.txt"]


def test_wait_for_disk_usage_higher_than_threshold(monkeypatch):
class Usage:
def __init__(self, free):
self.free = free * 1000 * 1000 * 1000

usages = [Usage(1), Usage(1), Usage(100)]

def fn(*_, **__):
value = usages.pop(0)
if value is None:
return value
return value

monkeypatch.setattr(shutil, "disk_usage", fn)

_wait_for_disk_usage_higher_than_threshold(10, sleep_time=0)

assert len(usages) == 0
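Context on the patching approach (not part of the diff): the test module imports shutil from lightning.data.streaming.data_processor, so monkeypatch.setattr(shutil, "disk_usage", fn) replaces the function on the same module object that _wait_for_disk_usage_higher_than_threshold reads, and the fake readings drive the wait loop. A minimal standalone sketch of the same pattern, with hypothetical names:

    import shutil
    from collections import namedtuple

    # Hypothetical stand-ins for this sketch only.
    FakeUsage = namedtuple("FakeUsage", ["total", "used", "free"])


    def test_disk_usage_can_be_stubbed(monkeypatch):
        # Queue of fake readings: the first call sees 1 GB free, the second sees 100 GB free.
        readings = [FakeUsage(0, 0, 1 * 1000**3), FakeUsage(0, 0, 100 * 1000**3)]
        monkeypatch.setattr(shutil, "disk_usage", lambda path: readings.pop(0))

        assert shutil.disk_usage("/").free / 1000**3 <= 25  # below the threshold
        assert shutil.disk_usage("/").free / 1000**3 > 25   # above the threshold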


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
def test_wait_for_file_to_exist():
import botocore