StreamingDataset improve deletion strategy (#19118)
Co-authored-by: thomas <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
(cherry picked from commit e6b79d9)
tchaton authored and Borda committed Dec 19, 2023
1 parent 2cb8f45 commit faa836d
Showing 6 changed files with 226 additions and 121 deletions.
src/lightning/data/streaming/cache.py (2 changes: 1 addition & 1 deletion)
@@ -51,7 +51,7 @@ def __init__(
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
item_loader: Optional[BaseItemLoader] = None,
max_cache_size: Union[int, str] = "200GB",
max_cache_size: Union[int, str] = "100GB",
serializers: Optional[Dict[str, Serializer]] = None,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
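
The only change in this file lowers the default max_cache_size from "200GB" to "100GB", so the deletion strategy kicks in at a smaller on-disk footprint. A minimal sketch of overriding the limit when building a Cache directly, assuming input_dir and chunk_bytes accept the string values shown (the path and sizes are placeholders):

from lightning.data.streaming.cache import Cache

# Keep at most ~50GB of optimised chunks on local disk (placeholder values).
cache = Cache(
    input_dir="/tmp/my-dataset-cache",
    chunk_bytes="64MB",
    max_cache_size="50GB",
)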
src/lightning/data/streaming/config.py (7 changes: 7 additions & 0 deletions)
@@ -128,6 +128,13 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
chunk = self._chunks[index.chunk_index]
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
for chunk_index, chunk in enumerate(self._chunks):
if chunk["filename"] == chunk_filename:
return chunk_index
raise ValueError(f"The provided filename doesn't exist {chunk_filename}.")

@classmethod
def load(
cls,
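
The new _get_chunk_index_from_filename helper maps a chunk filename back to its index in the loaded configuration, presumably so the deletion logic can tell which chunk a cached file on disk belongs to. A standalone sketch of the same lookup with a toy chunk list (the free function and the sample filenames are hypothetical):

def get_chunk_index_from_filename(chunks, chunk_filename):
    # Same linear scan as the method above, written as a free function.
    for chunk_index, chunk in enumerate(chunks):
        if chunk["filename"] == chunk_filename:
            return chunk_index
    raise ValueError(f"The provided filename doesn't exist {chunk_filename}.")

chunks = [{"filename": "chunk-0-0.bin"}, {"filename": "chunk-0-1.bin"}]
assert get_chunk_index_from_filename(chunks, "chunk-0-1.bin") == 1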
src/lightning/data/streaming/dataset.py (15 changes: 8 additions & 7 deletions)
@@ -18,7 +18,6 @@
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

@@ -31,7 +30,6 @@
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TIME_FORMAT,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
@@ -56,6 +54,7 @@ def __init__(
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: Optional[int] = None,
max_cache_size: Union[int, str] = "100GB",
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -68,6 +67,7 @@ def __init__(
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.
max_cache_size: The maximum cache size used by the StreamingDataset.
"""
super().__init__()
@@ -84,6 +84,7 @@ def __init__(
self.shuffle: bool = shuffle
self.drop_last = drop_last
self.seed = seed
self.max_cache_size = max_cache_size

self.cache: Optional[Cache] = None
self.distributed_env = _DistributedEnv.detect()
@@ -118,7 +119,11 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
self.input_dir.path = cache_path

cache = Cache(
input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers
input_dir=self.input_dir,
item_loader=self.item_loader,
chunk_bytes=1,
serializers=self.serializers,
max_cache_size=self.max_cache_size,
)
cache._reader._try_load_config()

@@ -391,10 +396,6 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
return cache_dir


def _string_to_datetime(item: str) -> datetime:
return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
state_dict: Dict[str, Any] = {}
if not os.path.exists(checkpoint_dir):
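
With the new max_cache_size argument, the dataset forwards the cache limit to the Cache it builds in _create_cache. A minimal usage sketch, assuming the constructor also accepts input_dir and shuffle as suggested by the attributes above (the bucket path and sizes are placeholders):

from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import StreamingDataset

# Cap the local chunk cache at 10GB instead of the new 100GB default.
dataset = StreamingDataset(
    input_dir="s3://my-bucket/optimised-dataset",  # hypothetical remote dataset dir
    shuffle=True,
    max_cache_size="10GB",
)

dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
for batch in dataloader:
    ...  # training step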
src/lightning/data/streaming/item_loader.py (58 changes: 47 additions & 11 deletions)
@@ -45,11 +45,21 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
"""Returns a list of tuple describing the indexes intervals of the chunks."""
pass

@abstractmethod
def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
"""Logic to load the chunk in background to gain some time."""
pass

@abstractmethod
def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> Any:
"""Returns an item loaded from a chunk."""
pass

@abstractmethod
def delete(self, chunk_index: int, chunk_filepath: str) -> None:
"""Delete a chunk from the local filesystem."""
pass


class PyTreeLoader(BaseItemLoader):
"""The Pytree Loader is the default loader of the Cache object."""
@@ -67,6 +77,9 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
begin += chunk["chunk_size"]
return intervals

def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
pass

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> bytes:
offset = (1 + (index - begin) if index >= begin else index + 1) * 4

@@ -106,6 +119,10 @@ def deserialize(self, raw_item_data: bytes) -> "PyTree":
idx += size
return tree_unflatten(data, self._config["data_spec"])

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
if os.path.exists(chunk_filepath):
os.remove(chunk_filepath)


class TokensLoader(BaseItemLoader):
def __init__(self, block_size: int):
@@ -146,6 +163,27 @@ def generate_intervals(self) -> List[Tuple[int, int]]:
begin += num_blocks
return intervals

def _load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
if chunk_index in self._mmaps:
return

chunk = self._chunks[chunk_index]

# Skip the header
# The number of items + the number of offsets (number of items in the chunk + 1)
# multiplied by the header encoding dtype (np.uint32)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
self._buffers[chunk_index] = memoryview(mmap) # type: ignore

def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
# This is called within the prepare chunks thread, so we overlap data loading with data reading.
if chunk_filepath not in self._chunk_filepaths:
self._chunk_filepaths[chunk_filepath] = True

self._load_chunk(chunk_index, chunk_filepath)

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
del self._chunk_filepaths[chunk_filepath]
@@ -163,20 +201,18 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str

self._chunk_filepaths[chunk_filepath] = True

if chunk_index not in self._mmaps:
# TODO: Add deletion and memmap close
chunk = self._chunks[chunk_index]

# Skip the header
# The number of items + the number of offsets (number of items in the chunk + 1)
# multiplied by the header encoding dtype (np.uint32)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
self._buffers[chunk_index] = memoryview(mmap) # type: ignore
self._load_chunk(chunk_index, chunk_filepath)

assert self._dtype

buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * (index - begin) * self._block_size
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
if os.path.exists(chunk_filepath):
if chunk_index in self._buffers:
del self._buffers[chunk_index]
if chunk_index in self._mmaps:
del self._mmaps[chunk_index]
os.remove(chunk_filepath)
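
BaseItemLoader now requires pre_load_chunk and delete alongside load_item_from_chunk, the TokensLoader memmap setup is factored into _load_chunk so pre_load_chunk can map chunks ahead of time from the prepare-chunks thread, and delete releases the memmap and memoryview before unlinking the file. A standalone sketch of the same header-offset arithmetic and cleanup pattern, outside the class (the helper names, the int64 dtype, and the bookkeeping dicts are assumptions for illustration):

import os

import numpy as np
import torch


def token_data_offset(chunk_size: int) -> int:
    # Header layout taken from the comment in _load_chunk: the item count plus
    # (chunk_size + 1) offsets, each a 4-byte uint32, precede the token data.
    return (1 + chunk_size + 1) * 4


def load_block(
    chunk_filepath: str, chunk_size: int, block_index: int, block_size: int, dtype: torch.dtype = torch.int64
) -> torch.Tensor:
    """Memory-map the token region of a chunk and read one fixed-size block."""
    mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=token_data_offset(chunk_size))
    buffer = memoryview(mmap)
    itemsize = torch.tensor([], dtype=dtype).element_size()
    return torch.frombuffer(buffer, dtype=dtype, count=block_size, offset=itemsize * block_index * block_size)


def delete_chunk(chunk_index: int, chunk_filepath: str, mmaps: dict, buffers: dict) -> None:
    """Drop the memmap and its memoryview before removing the file, as TokensLoader.delete does."""
    if os.path.exists(chunk_filepath):
        buffers.pop(chunk_index, None)
        mmaps.pop(chunk_index, None)
        os.remove(chunk_filepath)

In the real loader these live as methods on TokensLoader, with _mmaps and _buffers kept as instance dictionaries keyed by chunk index.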
