Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamingDataset improve deletion strategy #19118

Merged
merged 53 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 52 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
daf5220
update
Dec 5, 2023
b5b4d34
Merge branch 'master' into improve_speed
tchaton Dec 5, 2023
c0d9164
update
Dec 5, 2023
e627385
Merge branch 'improve_speed' of https://github.com/Lightning-AI/light…
Dec 5, 2023
007b9f9
update
Dec 5, 2023
fd10ed0
update
Dec 5, 2023
7ad008a
update
Dec 5, 2023
1f65326
update
Dec 5, 2023
36a62c9
update
Dec 5, 2023
50d2b6d
update
Dec 5, 2023
257d831
update
Dec 5, 2023
2b82d44
update
tchaton Dec 5, 2023
a318e65
update
tchaton Dec 5, 2023
428a3f5
update
tchaton Dec 5, 2023
1d984d1
update
tchaton Dec 5, 2023
fb412a8
update
tchaton Dec 5, 2023
b785957
update
tchaton Dec 5, 2023
89869f1
update
tchaton Dec 5, 2023
b9c5e53
update
tchaton Dec 5, 2023
734adf6
update
tchaton Dec 5, 2023
7980c05
update
tchaton Dec 5, 2023
9c3acb0
update
Dec 5, 2023
b59c38c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2023
a5e10b7
update
tchaton Dec 6, 2023
3e452ab
update
tchaton Dec 6, 2023
687bd5c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2023
7ca362c
update
Dec 6, 2023
033feed
update
Dec 6, 2023
e613774
update
tchaton Dec 6, 2023
08292f5
update
tchaton Dec 6, 2023
347db4e
update
Dec 6, 2023
513e554
update
tchaton Dec 6, 2023
d844e35
Merge branch 'improve_speed_2' of https://github.com/Lightning-AI/pyt…
tchaton Dec 6, 2023
5f4d6c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2023
34269d6
update
tchaton Dec 6, 2023
d5c5d89
update
tchaton Dec 6, 2023
f5f650d
update
tchaton Dec 6, 2023
879473f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2023
2a3acf7
update
Dec 6, 2023
e3b4389
update
tchaton Dec 6, 2023
771d49c
Merge branch 'improve_speed_2' of https://github.com/Lightning-AI/pyt…
tchaton Dec 6, 2023
282c550
update
tchaton Dec 6, 2023
6956504
update
Dec 6, 2023
49bd18f
update
Dec 6, 2023
838da1b
update
Dec 6, 2023
268b908
update
Dec 6, 2023
f22085f
update
Dec 6, 2023
bbc7887
update
Dec 6, 2023
ee3c288
remove_delete
Dec 6, 2023
3ae48b6
update
Dec 6, 2023
e4c9b72
update
Dec 6, 2023
db16b49
update
Dec 6, 2023
d6ac211
tune the timeout
awaelchli Dec 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
item_loader: Optional[BaseItemLoader] = None,
max_cache_size: Union[int, str] = "200GB",
max_cache_size: Union[int, str] = "100GB",
serializers: Optional[Dict[str, Serializer]] = None,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
Expand Down
7 changes: 7 additions & 0 deletions src/lightning/data/streaming/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
chunk = self._chunks[index.chunk_index]
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
for chunk_index, chunk in enumerate(self._chunks):
if chunk["filename"] == chunk_filename:
return chunk_index
raise ValueError(f"The provided filename doesn't exist {chunk_filename}.")

@classmethod
def load(
cls,
Expand Down
15 changes: 8 additions & 7 deletions src/lightning/data/streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

Expand All @@ -31,7 +30,6 @@
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TIME_FORMAT,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
Expand All @@ -56,6 +54,7 @@ def __init__(
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: Optional[int] = None,
max_cache_size: Union[int, str] = "100GB",
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

Expand All @@ -68,6 +67,7 @@ def __init__(
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.
max_cache_size: The maximum cache size used by the StreamingDataset.

"""
super().__init__()
Expand All @@ -84,6 +84,7 @@ def __init__(
self.shuffle: bool = shuffle
self.drop_last = drop_last
self.seed = seed
self.max_cache_size = max_cache_size

self.cache: Optional[Cache] = None
self.distributed_env = _DistributedEnv.detect()
Expand Down Expand Up @@ -118,7 +119,11 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
self.input_dir.path = cache_path

cache = Cache(
input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers
input_dir=self.input_dir,
item_loader=self.item_loader,
chunk_bytes=1,
serializers=self.serializers,
max_cache_size=self.max_cache_size,
)
cache._reader._try_load_config()

Expand Down Expand Up @@ -391,10 +396,6 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
return cache_dir


def _string_to_datetime(item: str) -> datetime:
    """Parse the timestamp embedded in a ``checkpoint-<timestamp>.json`` filename."""
    stamp = item.split("checkpoint-")[1].split(".json")[0]
    return datetime.strptime(stamp, _TIME_FORMAT)


def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
state_dict: Dict[str, Any] = {}
if not os.path.exists(checkpoint_dir):
Expand Down
58 changes: 47 additions & 11 deletions src/lightning/data/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
"""Returns a list of tuple describing the indexes intervals of the chunks."""
pass

@abstractmethod
def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
    """Hook for loading a chunk in the background ahead of its first read."""

@abstractmethod
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> Any:
    """Deserialize and return the item stored at ``index`` within the given chunk."""

@abstractmethod
def delete(self, chunk_index: int, chunk_filepath: str) -> None:
    """Remove the given chunk file from the local filesystem."""


class PyTreeLoader(BaseItemLoader):
"""The Pytree Loader is the default loader of the Cache object."""
Expand All @@ -67,6 +77,9 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
begin += chunk["chunk_size"]
return intervals

def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
    """No-op: this loader performs no background pre-loading of chunks."""

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
offset = (1 + (index - begin) if index >= begin else index + 1) * 4

Expand Down Expand Up @@ -106,6 +119,10 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
idx += size
return tree_unflatten(data, self._config["data_spec"])

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
    """Delete a chunk file from the local filesystem.

    Removal is attempted directly and a missing file is tolerated. This avoids
    the check-then-remove race in ``os.path.exists`` followed by ``os.remove``,
    where another worker could delete the file between the two calls and make
    ``os.remove`` raise.
    """
    try:
        os.remove(chunk_filepath)
    except FileNotFoundError:
        pass


class TokensLoader(BaseItemLoader):
def __init__(self, block_size: int):
Expand Down Expand Up @@ -146,6 +163,27 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
begin += num_blocks
return intervals

def _load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
if chunk_index in self._mmaps:
return

chunk = self._chunks[chunk_index]

# Skip the header
# The number of items + the number of offsets (number of items in the chunk + 1)
# multiplied by the header encoding dtype (np.uint32)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
self._buffers[chunk_index] = memoryview(mmap) # type: ignore

def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
    """Invoked from the prepare-chunks thread so chunk loading overlaps with data reading."""
    self._chunk_filepaths.setdefault(chunk_filepath, True)
    self._load_chunk(chunk_index, chunk_filepath)

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
del self._chunk_filepaths[chunk_filepath]
Expand All @@ -163,20 +201,18 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str

self._chunk_filepaths[chunk_filepath] = True

if chunk_index not in self._mmaps:
# TODO: Add deletion and memmap close
chunk = self._chunks[chunk_index]

# Skip the header
# The number of items + the number of offsets (number of items in the chunk + 1)
# multiplied by the header encoding dtype (np.uint32)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
self._buffers[chunk_index] = memoryview(mmap) # type: ignore
self._load_chunk(chunk_index, chunk_filepath)

assert self._dtype

buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * (index - begin) * self._block_size
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
    """Delete a chunk file and release its cached memory-map resources.

    The cached ``memoryview`` and ``np.memmap`` references are always dropped
    first — even if the file is already gone — so a chunk deleted externally
    cannot leave stale mapped data behind, and so no live mapping keeps the
    file open while it is being removed. A missing file is tolerated to avoid
    the check-then-remove race with other workers.
    """
    self._buffers.pop(chunk_index, None)
    self._mmaps.pop(chunk_index, None)
    try:
        os.remove(chunk_filepath)
    except FileNotFoundError:
        pass
Loading