
Commit 0a5cca6

tchaton, awaelchli, thomas, and pre-commit-ci[bot] authored
StreamingDataset: Cleanup chunks right away if the dataset doesn't fit within the cache (#19168)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: thomas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ecdfab0 commit 0a5cca6

File tree

5 files changed: +80 -58 lines changed


src/lightning/data/streaming/config.py

+8 -2

@@ -64,7 +64,7 @@ def __init__(
         self._downloader = None

         if remote_dir:
-            self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)
+            self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)

     def download_chunk_from_index(self, chunk_index: int) -> None:
         chunk_filename = self._chunks[chunk_index]["filename"]

@@ -85,6 +85,12 @@ def intervals(self) -> List[Tuple[int, int]]:
             raise RuntimeError("The intervals should be defined.")
         return self._intervals

+    @property
+    def num_bytes(self) -> int:
+        if self._config is None:
+            raise RuntimeError("The config should be defined.")
+        return sum(c["chunk_bytes"] for c in self._chunks)
+
     @property
     def data_format(self) -> Any:
         if self._config is None:

@@ -146,7 +152,7 @@ def load(
         cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

         if isinstance(remote_dir, str):
-            downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
+            downloader = get_downloader_cls(remote_dir, cache_dir, [])
             downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

         if not os.path.exists(cache_index_filepath):
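The new num_bytes property lets the reader decide up front whether the whole dataset fits in the cache. A minimal sketch of that decision, assuming config is a loaded ChunksConfig as in this diff; the budget value is illustrative, not from the commit:

# Hypothetical cache budget, in bytes (illustrative value).
max_cache_size = 10 * 1024**3  # 10 GiB

# If the full dataset cannot fit within the cache, chunks must be cleaned
# up as soon as they are processed (see reader.py below).
delete_chunks_when_processed = config.num_bytes > max_cache_size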

src/lightning/data/streaming/downloader.py

+11 -13

@@ -12,8 +12,8 @@
 # limitations under the License.
 import os
 import shutil
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Type
+from abc import ABC
+from typing import Any, Dict, List
 from urllib import parse

 from lightning.data.streaming.client import S3Client

@@ -31,28 +31,27 @@ def download_chunk_from_index(self, chunk_index: int) -> None:
         remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
         self.download_file(remote_chunkpath, local_chunkpath)

-    @abstractmethod
     def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
         pass


 class S3Downloader(Downloader):
-    @classmethod
-    def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
+    def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
+        super().__init__(remote_dir, cache_dir, chunks)
+        self._client = S3Client()
+
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         obj = parse.urlparse(remote_filepath)

         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")

-        # TODO: Add caching to avoid re-creating it
-        s3 = S3Client()
-
         from boto3.s3.transfer import TransferConfig

         extra_args: Dict[str, Any] = {}

         # Issue: https://github.com/boto/boto3/issues/3113
-        s3.client.download_file(
+        self._client.client.download_file(
             obj.netloc,
             obj.path.lstrip("/"),
             local_filepath,

@@ -62,8 +61,7 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None:


 class LocalDownloader(Downloader):
-    @classmethod
-    def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
+    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
             raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
         if remote_filepath != local_filepath:

@@ -73,8 +71,8 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
 _DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}


-def get_downloader_cls(remote_dir: str) -> Type[Downloader]:
+def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
     for k, cls in _DOWNLOADERS.items():
         if str(remote_dir).startswith(k):
-            return cls
+            return cls(remote_dir, cache_dir, chunks)
     raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
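get_downloader_cls keeps its historical name but now returns a constructed Downloader instance rather than the class itself, which lets S3Downloader hold a single S3Client across all of its downloads instead of re-creating one per file (the old TODO above). A short usage sketch; the bucket path and cache directory are hypothetical:

# The factory now instantiates the downloader for the caller.
downloader = get_downloader_cls("s3://my-bucket/dataset", "/cache", [])
downloader.download_file("s3://my-bucket/dataset/index.json", "/cache/index.json")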

src/lightning/data/streaming/item_loader.py

+2 -4

@@ -90,7 +90,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         first_exists = exists = os.path.exists(chunk_filepath)

         while not exists:
-            sleep(0.01)
+            sleep(0.1)
             exists = os.path.exists(chunk_filepath)

         # Wait to avoid any corruption when the file appears

@@ -166,7 +166,6 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
     def _load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         if chunk_index in self._mmaps:
             return
-
         chunk = self._chunks[chunk_index]

         # Skip the header

@@ -192,7 +191,7 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         first_exists = exists = os.path.exists(chunk_filepath)

         while not exists:
-            sleep(0.01)
+            sleep(0.1)
             exists = os.path.exists(chunk_filepath)

         # Wait to avoid any corruption when the file appears

@@ -202,7 +201,6 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         self._chunk_filepaths[chunk_filepath] = True

         self._load_chunk(chunk_index, chunk_filepath)
-
         assert self._dtype

         buffer: bytes = self._buffers[chunk_index]
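Both sleep(0.01) calls above were raised to sleep(0.1): these loops poll for a chunk file that the downloader thread is still writing, and a longer back-off burns far fewer CPU cycles while waiting. A self-contained sketch of the polling pattern, with a hypothetical helper name:

import os
from time import sleep

def _wait_for_file(path: str) -> None:
    # Back off 100 ms between existence checks so the waiting worker
    # doesn't busy-spin while the download completes.
    while not os.path.exists(path):
        sleep(0.1)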

src/lightning/data/streaming/reader.py

+36 -37

@@ -18,7 +18,6 @@
 from logging import Logger
 from queue import Empty
 from threading import Thread
-from time import sleep
 from typing import Any, Dict, List, Optional, Tuple, Union

 from lightning.data.streaming.config import ChunksConfig

@@ -30,14 +29,21 @@
 warnings.filterwarnings("ignore", message=".*The given buffer is not writable.*")

-
 if _TORCH_GREATER_EQUAL_2_1_0:
     pass


 logger = Logger(__name__)


+_END_TOKEN = "END"
+
+# Note: The timeout here should not be too short. We need to prevent the caller from aggressively
+# querying the queue and consuming too many CPU cycles.
+_DEFAULT_TIMEOUT = 0.1
+_LONG_DEFAULT_TIMEOUT = 5
+
+
 class PrepareChunksThread(Thread):
     """This thread is responsible to download the chunks associated to a given worker."""

@@ -59,22 +65,7 @@ def __init__(
         self._parent_cache_dir = os.path.dirname(self._config._cache_dir)
         self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
         self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
-        self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()
-
-        # populate back the queues with existing items. As they already exists, this is almost a no-op
-        for chunk_index in self._collect_ordered_chunk_indexes_from_cache():
-            self._to_download_queue.put(chunk_index)
-            self._to_delete_queue.put(chunk_index)
-
-    def _collect_ordered_chunk_indexes_from_cache(self) -> List[int]:
-        """List the chunks available in the cache, order them based on their creation time and retrieves their
-        indexes."""
-        chunk_indexes = [
-            [self._config._get_chunk_index_from_filename(f), os.path.getctime(os.path.join(self._config._cache_dir, f))]
-            for f in os.listdir(self._config._cache_dir)
-            if f.endswith(".bin")
-        ]
-        return [int(x[0]) for x in sorted(chunk_indexes, key=lambda x: x[1])]
+        self._delete_chunks_when_processed = self._config.num_bytes > max_cache_size if max_cache_size else False

     def download(self, chunk_indexes: List[int]) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""

@@ -93,10 +84,15 @@ def _delete(self, chunk_index: int) -> None:

     def stop(self) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
-        self._to_stop_queue.put(True)
+        self._to_download_queue.put(_END_TOKEN)

     def _maybe_delete_chunks(self) -> None:
-        chunk_index = _get_from_queue(self._to_delete_queue)
+        reached_pre_download = self._pre_download_counter == self._max_pre_download
+
+        # we have already pre-downloaded some chunks, we just need to wait for them to be processed.
+        chunk_index = _get_from_queue(
+            self._to_delete_queue, timeout=_LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT
+        )

         if chunk_index is not None:
             self._pre_download_counter -= 1

@@ -105,14 +101,17 @@ def _maybe_delete_chunks(self) -> None:
             self._chunks_index_to_be_deleted.append(chunk_index)

         # Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it
-        while (
-            self._max_cache_size
-            and self._chunks_index_to_be_deleted
-            and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
-        ):
+        while self._max_cache_size and self._chunks_index_to_be_deleted and self._can_delete_chunk():
             # Delete the oldest chunk
             self._delete(self._chunks_index_to_be_deleted.pop(0))

+        return
+
+    def _can_delete_chunk(self) -> bool:
+        if self._delete_chunks_when_processed:
+            return self._pre_download_counter == self._max_pre_download - 1
+        return self._max_cache_size is not None and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
+
     def _pre_load_chunk(self, chunk_index: int) -> None:
         chunk_filepath, _, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
         self._item_loader.pre_load_chunk(chunk_index, chunk_filepath)

@@ -121,6 +120,9 @@ def run(self) -> None:
         while True:
             if self._pre_download_counter <= self._max_pre_download:
                 chunk_index = _get_from_queue(self._to_download_queue)
+                if chunk_index == _END_TOKEN:
+                    return
+
                 if chunk_index is not None:
                     self._config.download_chunk_from_index(chunk_index)

@@ -135,11 +137,6 @@ def run(self) -> None:
             if self._max_cache_size:
                 self._maybe_delete_chunks()

-            if _get_from_queue(self._to_stop_queue):
-                return
-
-            sleep(0.05)
-

 class BinaryReader:
     def __init__(

@@ -238,6 +235,9 @@ def read(self, index: ChunkedIndex) -> Any:
             assert self._prepare_thread
             self._prepare_thread.download([index.chunk_index])

+            if self._last_chunk_index is None:
+                self._last_chunk_index = index.chunk_index
+
         # Fetch the element
         chunk_filepath, begin, _ = self.config[index]
         item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)

@@ -246,9 +246,10 @@ def read(self, index: ChunkedIndex) -> Any:
         # Otherwise, this could trigger segmentation fault error depending on the item loader used.
         if self._config and self._config._remote_dir and index.chunk_index != self._last_chunk_index:
             assert self._prepare_thread
-            if self._last_chunk_index is not None:
-                # inform the chunk has been completely consumed
-                self._prepare_thread.delete([self._last_chunk_index])
+            assert self._last_chunk_index is not None
+
+            # inform the chunk has been completely consumed
+            self._prepare_thread.delete([self._last_chunk_index])

             # track the new chunk index as the latest one
             self._last_chunk_index = index.chunk_index

@@ -294,11 +295,9 @@ def _get_folder_size(path: str) -> int:
     return size


-def _get_from_queue(queue: multiprocessing.Queue) -> Optional[Any]:
+def _get_from_queue(queue: multiprocessing.Queue, timeout: float = _DEFAULT_TIMEOUT) -> Optional[Any]:
     try:
-        # Note: The timeout here should not be too short. We need to prevent the caller from aggressively
-        # querying the queue and consuming too many CPU cycles.
-        return queue.get(timeout=0.1)
+        return queue.get(timeout=timeout)
     except Empty:
         pass
     except OSError as e:
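Thread shutdown also changed: instead of polling a dedicated _to_stop_queue and sleeping 50 ms per loop iteration, stop() now pushes the _END_TOKEN sentinel through the existing download queue, and run() exits when it dequeues it. A minimal, standalone sketch of this sentinel pattern using only the standard library (simplified names, not the library's code):

import queue
import threading

_END_TOKEN = "END"

def _run(q: queue.Queue) -> None:
    while True:
        try:
            item = q.get(timeout=0.1)  # short timeout keeps the loop responsive
        except queue.Empty:
            continue
        if item == _END_TOKEN:  # sentinel received: exit the thread
            return
        print(f"processing chunk {item}")

q: queue.Queue = queue.Queue()
thread = threading.Thread(target=_run, args=(q,))
thread.start()
q.put(0)           # enqueue work as usual
q.put(_END_TOKEN)  # stop() enqueues the sentinel instead of polling a stop queue
thread.join()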

tests/tests_data/streaming/test_reader.py

+23 -2

@@ -4,11 +4,12 @@
 import numpy as np
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.config import ChunkedIndex
-from lightning.data.streaming.reader import _get_folder_size
+from lightning.data.streaming.item_loader import PyTreeLoader
+from lightning.data.streaming.reader import PrepareChunksThread, _get_folder_size
 from lightning_cloud.resolver import Dir


-def test_reader_chunk_removal(tmpdir, monkeypatch):
+def test_reader_chunk_removal(tmpdir):
     cache_dir = os.path.join(tmpdir, "cache_dir")
     remote_dir = os.path.join(tmpdir, "remote_dir")
     os.makedirs(cache_dir, exist_ok=True)

@@ -79,3 +80,23 @@ def test_get_folder_size(tmpdir):
     np.save(os.path.join(tmpdir, "array_2.npy"), array)

     assert _get_folder_size(tmpdir) == 928 * 2
+
+
+def test_prepare_chunks_thread(tmpdir):
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+    os.makedirs(cache_dir, exist_ok=True)
+    cache = Cache(input_dir=cache_dir, chunk_size=2, max_cache_size=28020)
+
+    for i in range(25):
+        cache[i] = i
+
+    cache.done()
+    cache.merge()
+
+    cache._reader._try_load_config()
+
+    thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
+    assert thread._delete_chunks_when_processed
+
+    thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=10000)
+    assert not thread._delete_chunks_when_processed
