StreamingDataset improve deletion strategy #19118

Merged 53 commits, Dec 7, 2023. The changes shown below are from a single commit of the 53.

Commits (53)
daf5220  update (Dec 5, 2023)
b5b4d34  Merge branch 'master' into improve_speed (tchaton, Dec 5, 2023)
c0d9164  update (Dec 5, 2023)
e627385  Merge branch 'improve_speed' of https://github.com/Lightning-AI/light… (Dec 5, 2023)
007b9f9  update (Dec 5, 2023)
fd10ed0  update (Dec 5, 2023)
7ad008a  update (Dec 5, 2023)
1f65326  update (Dec 5, 2023)
36a62c9  update (Dec 5, 2023)
50d2b6d  update (Dec 5, 2023)
257d831  update (Dec 5, 2023)
2b82d44  update (tchaton, Dec 5, 2023)
a318e65  update (tchaton, Dec 5, 2023)
428a3f5  update (tchaton, Dec 5, 2023)
1d984d1  update (tchaton, Dec 5, 2023)
fb412a8  update (tchaton, Dec 5, 2023)
b785957  update (tchaton, Dec 5, 2023)
89869f1  update (tchaton, Dec 5, 2023)
b9c5e53  update (tchaton, Dec 5, 2023)
734adf6  update (tchaton, Dec 5, 2023)
7980c05  update (tchaton, Dec 5, 2023)
9c3acb0  update (Dec 5, 2023)
b59c38c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 5, 2023)
a5e10b7  update (tchaton, Dec 6, 2023)
3e452ab  update (tchaton, Dec 6, 2023)
687bd5c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 6, 2023)
7ca362c  update (Dec 6, 2023)
033feed  update (Dec 6, 2023)
e613774  update (tchaton, Dec 6, 2023)
08292f5  update (tchaton, Dec 6, 2023)
347db4e  update (Dec 6, 2023)
513e554  update (tchaton, Dec 6, 2023)
d844e35  Merge branch 'improve_speed_2' of https://github.com/Lightning-AI/pyt… (tchaton, Dec 6, 2023)
5f4d6c0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 6, 2023)
34269d6  update (tchaton, Dec 6, 2023)
d5c5d89  update (tchaton, Dec 6, 2023)
f5f650d  update (tchaton, Dec 6, 2023)
879473f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 6, 2023)
2a3acf7  update (Dec 6, 2023)
e3b4389  update (tchaton, Dec 6, 2023)
771d49c  Merge branch 'improve_speed_2' of https://github.com/Lightning-AI/pyt… (tchaton, Dec 6, 2023)
282c550  update (tchaton, Dec 6, 2023)
6956504  update (Dec 6, 2023)
49bd18f  update (Dec 6, 2023)
838da1b  update (Dec 6, 2023)
268b908  update (Dec 6, 2023)
f22085f  update (Dec 6, 2023)
bbc7887  update (Dec 6, 2023)
ee3c288  remove_delete (Dec 6, 2023)
3ae48b6  update (Dec 6, 2023)
e4c9b72  update (Dec 6, 2023)
db16b49  update (Dec 6, 2023)
d6ac211  tune the timeout (awaelchli, Dec 7, 2023)
Viewing changes from commit 2b82d44fff4710031d275b95e7e2df3fd8f09963 ("update"), committed by tchaton on Dec 5, 2023. Created on GitHub.com and signed with GitHub's verified signature.
2 changes: 1 addition & 1 deletion  src/lightning/data/streaming/cache.py

@@ -51,7 +51,7 @@ def __init__(
         chunk_size: Optional[int] = None,
         chunk_bytes: Optional[Union[int, str]] = None,
         item_loader: Optional[BaseItemLoader] = None,
-        max_cache_size: Union[int, str] = "200GB",
+        max_cache_size: Union[int, str] = "10GB",
         serializers: Optional[Dict[str, Serializer]] = None,
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
19 changes: 19 additions & 0 deletions  src/lightning/data/streaming/item_loader.py

@@ -50,6 +50,11 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         """Returns an item loaded from a chunk."""
         pass

+    @abstractmethod
+    def delete(self, chunk_index: int, chunk_filepath: str):
+        """Delete a chunk."""
+        pass
+

 class PyTreeLoader(BaseItemLoader):
     """The Pytree Loader is the default loader of the Cache object."""

@@ -106,6 +111,11 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
         idx += size
         return tree_unflatten(data, self._config["data_spec"])

+    def delete(self, chunk_index: int, chunk_filepath: str):
+        if os.path.exists(chunk_filepath):
+            os.remove(chunk_filepath)
+            print(chunk_filepath)
+

 class TokensLoader(BaseItemLoader):
     def __init__(self, block_size: int):

@@ -180,3 +190,12 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
         buffer: bytes = self._buffers[chunk_index]
         offset = self._dtype.itemsize * (index - begin) * self._block_size
         return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
+
+    def delete(self, chunk_index: int, chunk_filepath: str):
+        self._mmaps[chunk_index]._mmap.close()
+        del self._mmaps[chunk_index]
+        del self._buffers[chunk_index]
+
+        if os.path.exists(chunk_filepath):
+            os.remove(chunk_filepath)
+            print(chunk_filepath)
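
Each loader now owns its cleanup because deletion is loader-specific: PyTreeLoader only unlinks the file, while TokensLoader must first close the memory map it holds over the chunk. The standalone sketch below (illustrative only, not library code) shows why that ordering matters: an open mmap keeps an OS-level handle on the file, and on Windows os.remove fails outright while a mapping is open.

import mmap
import os
import tempfile

# Create a throwaway "chunk" file and map it, mimicking TokensLoader's buffers.
path = os.path.join(tempfile.mkdtemp(), "chunk-0.bin")
with open(path, "wb") as f:
    f.write(b"\x00" * 1024)

fileobj = open(path, "rb")
mapping = mmap.mmap(fileobj.fileno(), 0, access=mmap.ACCESS_READ)

# Release the mapping and the descriptor first...
mapping.close()
fileobj.close()

# ...then unlinking the file is safe on every platform.
os.remove(path)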
73 changes: 41 additions & 32 deletions  src/lightning/data/streaming/reader.py

@@ -36,12 +36,14 @@
 class PrepareChunksThread(Thread):
     """This thread is responsible to download the chunks associated to a given worker."""

-    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
+    def __init__(self, config: ChunksConfig, item_loader, max_cache_size: Optional[int] = None) -> None:
         super().__init__(daemon=True)
         self._config = config
+        self._item_loader = item_loader
         self._chunks_index_to_be_downloaded: List[int] = []
         self._chunks_index_to_be_deleted: List[int] = []
         self._max_cache_size = max_cache_size
+        self._parent_cache_dir = os.path.dirname(self._config._cache_dir)
         self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
         self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
         self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()

@@ -58,26 +60,42 @@ def delete(self, chunk_indexes: List[int]) -> None:

     def _delete(self, chunk_index: int) -> None:
         chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
-
-        if os.path.exists(chunk_filepath):
-            os.remove(chunk_filepath)
+        self._item_loader.delete(chunk_index, chunk_filepath)

     def stop(self) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
         self._to_stop_queue.put(None)

+    def _delete_chunks(self):
+        try:
+            chunk_index = self._to_delete_queue.get(timeout=0.01)
+            if self._max_cache_size:
+                total = _get_folder_size(self._parent_cache_dir)
+                if total >= self._max_cache_size:
+                    self._chunks_index_to_be_deleted.append(chunk_index)
+
+                    while self._max_cache_size and self._chunks_index_to_be_deleted and total >= self._max_cache_size:
+                        self._delete(self._chunks_index_to_be_deleted.pop(0))
+                        total = _get_folder_size(self._parent_cache_dir)
+                else:
+                    self._chunks_index_to_be_deleted.append(chunk_index)
+        except Empty:
+            pass
+        except OSError as e:
+            # handle closed queue before the thread terminates
+            if "handle is closed" in str(e):
+                pass
+            else:
+                raise e
+
     def run(self) -> None:
         while True:
             try:
                 chunk_index = self._to_download_queue.get(timeout=0.01)

                 # Before downloading, check whether we have enough space
-                if (
-                    self._max_cache_size
-                    and self._chunks_index_to_be_deleted
-                    and shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size
-                ):
-                    self._delete(self._chunks_index_to_be_deleted.pop(0))
+                while self._max_cache_size and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size:
+                    self._delete_chunks()

                 self._config.download_chunk_from_index(chunk_index)
             except Empty:

@@ -89,27 +107,7 @@ def run(self) -> None:
                 else:
                     raise e

-            try:
-                chunk_index = self._to_delete_queue.get(timeout=0.01)
-                if self._max_cache_size:
-                    if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
-                        self._chunks_index_to_be_deleted.append(chunk_index)
-
-                        # Delete 2 chunk at the time to give enough space while not blocking downloads
-                        for chunk_index in self._chunks_index_to_be_deleted[:2]:
-                            self._delete(chunk_index)
-
-                        self._chunks_index_to_be_deleted = self._chunks_index_to_be_deleted[2:]
-                    else:
-                        self._chunks_index_to_be_deleted.append(chunk_index)
-            except Empty:
-                pass
-            except OSError as e:
-                # handle closed queue before the thread terminates
-                if "handle is closed" in str(e):
-                    pass
-                else:
-                    raise e
+            self._delete_chunks()

             try:
                 self._to_stop_queue.get(timeout=0.01)

@@ -212,7 +210,7 @@ def read(self, index: ChunkedIndex) -> Any:
         if self._config and self._config._remote_dir:
             # Create and start the prepare chunks thread
             if self._prepare_thread is None and self._config:
-                self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
+                self._prepare_thread = PrepareChunksThread(self._config, self._item_loader, self._max_cache_size)
                 self._prepare_thread.start()
             if index.chunk_indexes:
                 self._prepare_thread.download(index.chunk_indexes)

@@ -255,3 +253,14 @@ def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         state["_prepare_thread"] = None
         return state
+
+
+def _get_folder_size(path: str) -> int:
+    size = 0
+    for dirpath, _, filenames in os.walk(str(path)):
+        for filename in filenames:
+            try:
+                size += os.stat(os.path.join(dirpath, filename)).st_size
+            except FileNotFoundError:
+                pass
+    return size
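
The substantive fix in this file: the old code compared shutil.disk_usage(cache_dir).total, which reports the capacity of the whole filesystem rather than the bytes the cache occupies, against max_cache_size, so the check was effectively decided by disk size instead of cache usage. The new _get_folder_size walks the cache directory and sums real file sizes, and chunks queued on _to_delete_queue are evicted oldest-first until usage drops back under budget. Below is a standalone simulation of that FIFO eviction policy; all names are made up for illustration.

from collections import deque

def evict_until_under_budget(chunk_sizes: dict, pending: deque, max_cache_size: int) -> list:
    # chunk_sizes: chunk_index -> bytes on disk (stands in for _get_folder_size).
    # pending: FIFO queue of chunk indexes eligible for deletion, oldest first.
    deleted = []
    total = sum(chunk_sizes.values())
    while pending and total >= max_cache_size:
        index = pending.popleft()        # oldest tracked chunk goes first
        total -= chunk_sizes.pop(index)
        deleted.append(index)
    return deleted

# Three 4 GB chunks against a 10 GB budget: evicting the oldest one
# brings usage to 8 GB, back under the limit.
sizes = {i: 4 * 2**30 for i in range(3)}
print(evict_until_under_budget(sizes, deque([0, 1, 2]), 10 * 2**30))  # -> [0]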