Resolve Item Loader bugs #19017

Merged
merged 37 commits into from
Nov 16, 2023
Changes from 18 commits
Commits
37 commits
e3fbc39
update
Nov 16, 2023
284a18f
update
Nov 16, 2023
4fcd6e3
update
Nov 16, 2023
91c5ecf
update
Nov 16, 2023
2418e26
update
Nov 16, 2023
43364a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2023
1b47645
update
Nov 16, 2023
9d098b3
update
Nov 16, 2023
88e9627
update
Nov 16, 2023
1b0bbaf
Merge branch 'master' into resolve_items_loader_bug
tchaton Nov 16, 2023
ddae27c
update
Nov 16, 2023
ca4e7d6
update
Nov 16, 2023
8a2adac
Merge branch 'resolve_items_loader_bug' of https://github.com/Lightni…
Nov 16, 2023
8c4eb03
update
Nov 16, 2023
32fd505
update
Nov 16, 2023
6f66a27
update
Nov 16, 2023
ee8226c
update
Nov 16, 2023
be08c65
Merge branch 'master' into resolve_items_loader_bug
tchaton Nov 16, 2023
b2a2848
update
Nov 16, 2023
0d424b5
update
Nov 16, 2023
06d53c0
Merge branch 'resolve_items_loader_bug' of https://github.com/Lightni…
Nov 16, 2023
03ed21f
update
Nov 16, 2023
01a41f5
update
Nov 16, 2023
4f2acb7
update
Nov 16, 2023
e4a2059
update
Nov 16, 2023
21fedd4
Merge branch 'master' into resolve_items_loader_bug
tchaton Nov 16, 2023
723b788
update
Nov 16, 2023
1c85956
Merge branch 'resolve_items_loader_bug' of https://github.com/Lightni…
Nov 16, 2023
7ab8dbc
update
Nov 16, 2023
8c3379b
updte
Nov 16, 2023
a1ea1d4
update
Nov 16, 2023
dd27f1b
update
Nov 16, 2023
0a25bfa
update
Nov 16, 2023
bef2736
update
Nov 16, 2023
ccf7628
update
Nov 16, 2023
b553acb
update
Nov 16, 2023
e087bac
update
Nov 16, 2023
17 changes: 11 additions & 6 deletions src/lightning/data/streaming/dataset.py
@@ -11,7 +11,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import hashlib
import os
from typing import Any, Dict, List, Optional, Union
@@ -84,12 +83,16 @@ def __init__(

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
cache_dir = copy.deepcopy(self.input_dir)
if cache_path:
cache_dir.path = cache_path

cache = Cache(input_dir=cache_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers)
# TODO: Move this to lightning-cloud
if "this_" not in self.input_dir.path:
cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
if cache_path is not None:
self.input_dir.path = cache_path

cache = Cache(
input_dir=self.input_dir, item_loader=self.item_loader, chunk_bytes=1, serializers=self.serializers
)
cache._reader._try_load_config()

if not cache.filled:
@@ -136,6 +139,7 @@ def __iter__(self) -> "StreamingDataset":
self.current_indexes = []
self.chunk_index = 0
self.index = 0
self.has_triggered_download = False

return self

@@ -175,6 +179,7 @@ def __next__(self) -> Any:
ChunkedIndex(
index=index,
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunk indexes only on the first call
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
)
)
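Note on the new `has_triggered_download` reset and the `chunk_indexes` argument: only the very first `ChunkedIndex` of an iteration carries the full list of worker chunks, so the reader can queue all downloads once. A minimal sketch of that behaviour (made-up chunk numbers, not the real classes):

worker_chunks = [3, 7, 11]
has_triggered_download = False

def next_index(index, chunk_index):
    # Attach the full chunk list only the first time, then mark downloads as triggered.
    global has_triggered_download
    chunk_indexes = None if has_triggered_download else worker_chunks
    has_triggered_download = True
    return {"index": index, "chunk_index": chunk_index, "chunk_indexes": chunk_indexes}

first = next_index(0, 3)   # carries [3, 7, 11] -> the reader pre-downloads every worker chunk
second = next_index(1, 3)  # chunk_indexes is None -> nothing is queued again
assert first["chunk_indexes"] == [3, 7, 11]
assert second["chunk_indexes"] is None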
6 changes: 5 additions & 1 deletion src/lightning/data/streaming/functions.py
@@ -42,7 +42,7 @@ def _get_input_dir(inputs: Sequence[Any]) -> str:
if len(indexed_paths) == 0:
raise ValueError(f"The provided item {inputs[0]} didn't contain any filepaths.")

absolute_path = str(Path(indexed_paths[0]).resolve())
absolute_path = str(Path(list(indexed_paths.values())[0]).resolve())

if indexed_paths[0] != absolute_path:
raise ValueError("The provided path should be absolute.")
@@ -189,6 +189,7 @@ def optimize(
num_nodes: Optional[int] = None,
machine: Optional[str] = None,
num_downloaders: Optional[int] = None,
reorder_files: bool = True,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.

@@ -205,6 +206,8 @@
num_nodes: When doing remote execution, the number of nodes to use.
machine: When doing remote execution, the machine to use.
num_downloaders: The number of downloaders per worker.
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.

"""
if not isinstance(inputs, Sequence):
@@ -235,6 +238,7 @@
num_workers=num_workers or os.cpu_count(),
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
reorder_files=reorder_files,
)
return data_processor.run(
LambdaDataChunkRecipe(
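A hedged usage sketch of the new `reorder_files` flag — the import path and the `fn`, `inputs`, `output_dir` and `chunk_bytes` arguments below are assumptions based on the surrounding code; only `reorder_files` itself comes from this diff:

from lightning.data.streaming.functions import optimize

def fn(filepath):
    # Hypothetical per-file work; yields the items that end up in the optimized chunks.
    yield filepath

if __name__ == "__main__":
    optimize(
        fn=fn,
        inputs=["data/a.txt", "data/b.txt"],  # made-up input files
        output_dir="optimized/",
        chunk_bytes="64MB",
        # Keep the original sample order instead of sorting files by size across workers.
        reorder_files=False,
    )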
13 changes: 8 additions & 5 deletions src/lightning/data/streaming/item_loader.py
@@ -110,7 +110,6 @@ def __init__(self, block_size: int):

super().__init__()
self._block_size = block_size
self._intervals: List[Tuple[int, int]] = []
self._mmaps: Dict[int, np.memmap] = {}
self._buffers: Dict[int, bytes] = {}
self._dtype: Optional[torch.dtype] = None
@@ -123,16 +122,16 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
raise ValueError("The provided chunks isn't properly setup.")

def generate_intervals(self) -> List[Tuple[int, int]]:
intervals = []
begin = 0
end = 0
for chunk in self._chunks:
dim = chunk["dim"]
num_blocks = dim // self._block_size
end += num_blocks
self._intervals.append((begin, end))
intervals.append((begin, end))
begin += num_blocks

return self._intervals
return intervals

def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str, begin: int) -> torch.Tensor:
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
@@ -149,6 +148,10 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
if chunk_index not in self._mmaps:
# TODO: Add deletion and memmap close
chunk = self._chunks[chunk_index]

# Skip the header
# The number of items + the number of offsets (number of items in the chunk + 1)
# multiplied by the header encoding dtype (np.uint32)
offset = (1 + chunk["chunk_size"] + 1) * 4
mmap = np.memmap(chunk_filepath, mode="r", order="C", offset=offset)
self._mmaps[chunk_index] = mmap
@@ -157,5 +160,5 @@ def load_item_from_chunk(self, index: int, chunk_index: int, chunk_filepath: str
assert self._dtype

buffer: bytes = self._buffers[chunk_index]
offset = self._dtype.itemsize * ((index - begin) if index >= begin else index + 1)
offset = self._dtype.itemsize * (index - begin) * self._block_size
return torch.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
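For context on the two `TokensLoader` fixes (intervals are rebuilt locally on each `generate_intervals` call instead of accumulating in `self._intervals`, and the item offset now skips the header and accounts for the block size), here is a small standalone sketch with made-up numbers:

import numpy as np

# Made-up example: three chunks, each storing 11 blocks of 1025 tokens as 4-byte ints.
chunk_size = 11    # items (blocks) per chunk
block_size = 1025  # tokens per block
itemsize = 4       # byte size of the token dtype (e.g. torch.int32)

# generate_intervals(): consecutive (begin, end) block ranges, one tuple per chunk.
intervals, begin, end = [], 0, 0
for dim in [chunk_size * block_size] * 3:
    num_blocks = dim // block_size
    end += num_blocks
    intervals.append((begin, end))
    begin += num_blocks
assert intervals == [(0, 11), (11, 22), (22, 33)]

# Header layout: the item count plus (chunk_size + 1) offsets, all encoded as np.uint32,
# so the memmap is opened with offset=(1 + chunk_size + 1) * 4 bytes to skip it.
header_bytes = (1 + chunk_size + 1) * np.dtype(np.uint32).itemsize
assert header_bytes == 52

# Inside the data region, the item with global `index` in the chunk whose interval starts
# at `begin` sits at local position (index - begin), each block spanning block_size * itemsize bytes.
index, begin = 25, 22  # third chunk, interval (22, 33)
offset = itemsize * (index - begin) * block_size
assert offset == 3 * 1025 * 4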
2 changes: 0 additions & 2 deletions src/lightning/data/streaming/reader.py
@@ -141,7 +141,6 @@
self._rank: Optional[int] = None
self._config: Optional[ChunksConfig] = None
self._prepare_thread: Optional[PrepareChunksThread] = None
self._chunks_index_to_be_downloaded: List[int] = []
self._item_loader = item_loader or PyTreeLoader()
self._last_chunk_index: Optional[int] = None
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size))
@@ -193,7 +192,6 @@ def read(self, index: ChunkedIndex) -> Any:
self._prepare_thread = PrepareChunksThread(self._config, self._max_cache_size)
self._prepare_thread.start()
if index.chunk_indexes:
self._chunks_index_to_be_downloaded.extend(index.chunk_indexes)
self._prepare_thread.download(index.chunk_indexes)

# If the chunk_index isn't already in the download and delete queues, add it.
12 changes: 4 additions & 8 deletions src/lightning/data/streaming/shuffle.py
@@ -13,7 +13,7 @@

from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List
from typing import Any, List, Tuple

import numpy as np

@@ -58,15 +58,11 @@ class NoShuffle(Shuffle):

@lru_cache(maxsize=10)
def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
self.random_state = np.random.RandomState(seed=self.seed + current_epoch) # type: ignore
chunk_intervals = self.cache.get_chunk_intervals()
indexes = list(range(len(chunk_intervals)))
shuffled_chunk_intervals = np.asarray(chunk_intervals)[indexes]

chunks_per_ranks: List[List[int]] = [[] for _ in range(distributed_env.world_size)]
intervals_per_ranks: List[List[List[int]]] = [[] for _ in range(distributed_env.world_size)]
for index, (chunk_index, chunk_interval) in enumerate(zip(indexes, shuffled_chunk_intervals)):
replica_index = index % distributed_env.world_size
intervals_per_ranks: List[List[Tuple]] = [[] for _ in range(distributed_env.world_size)]
for chunk_index, chunk_interval in enumerate(chunk_intervals):
replica_index = chunk_index % distributed_env.world_size
chunks_per_ranks[replica_index].append(chunk_index)
intervals_per_ranks[replica_index].append(chunk_interval)

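The simplified `NoShuffle` assignment is a plain round-robin of chunks over ranks; a small sketch with made-up intervals and `world_size=2`:

# Made-up chunk intervals; two ranks.
chunk_intervals = [(0, 10), (10, 25), (25, 30), (30, 42)]
world_size = 2

chunks_per_ranks = [[] for _ in range(world_size)]
intervals_per_ranks = [[] for _ in range(world_size)]
for chunk_index, chunk_interval in enumerate(chunk_intervals):
    replica_index = chunk_index % world_size
    chunks_per_ranks[replica_index].append(chunk_index)
    intervals_per_ranks[replica_index].append(chunk_interval)

assert chunks_per_ranks == [[0, 2], [1, 3]]
assert intervals_per_ranks == [[(0, 10), (25, 30)], [(10, 25), (30, 42)]]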
35 changes: 35 additions & 0 deletions tests/tests_data/streaming/test_cache.py
@@ -23,6 +23,7 @@
from lightning.data.streaming import Cache
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.serializers import Serializer
from lightning.fabric import Fabric
from lightning.pytorch.demos.boring_classes import RandomDataset
@@ -276,3 +277,37 @@ def test_custom_serializer(tmpdir):
cache.done()
cache.merge()
assert isinstance(cache[0][0], bytes)


def test_cache_for_text_tokens(tmpdir):
seed_everything(42)

block_size = 1024 + 1
cache = Cache(input_dir=str(tmpdir), chunk_size=block_size * 11, item_loader=TokensLoader(block_size))
text_idxs_list = []

counter = 0
while True:
text_ids = torch.randint(0, 1000, (np.random.randint(0, 1000),)).to(torch.int)
text_idxs_list.append(text_ids)
chunk_filepath = cache._add_item(counter, text_ids)
if chunk_filepath:
break
counter += 1

cache.done()
cache.merge()

assert len(cache) == 10

cache_0 = cache[0]
cache_1 = cache[1]
assert len(cache_0) == block_size
assert len(cache_1) == block_size
assert not torch.equal(cache_0, cache[1])
indices = torch.cat(text_idxs_list, dim=0)
assert torch.equal(cache_0, indices[: len(cache_0)])
assert torch.equal(cache_1, indices[len(cache_0) : len(cache_0) + len(cache_1)])

with pytest.raises(ValueError, match="TokensLoader"):
len(Cache(str(tmpdir), chunk_size=block_size * 11))