diff --git a/src/lightning/data/streaming/__init__.py b/src/lightning/data/streaming/__init__.py
index 2282c740336df..dff2bfbdd2c17 100644
--- a/src/lightning/data/streaming/__init__.py
+++ b/src/lightning/data/streaming/__init__.py
@@ -16,4 +16,11 @@
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.item_loader import TokensLoader
 
-__all__ = ["Cache", "DataProcessor", "StreamingDataset", "DataTransformRecipe", "DataChunkRecipe", "TokensLoader"]
+__all__ = [
+    "Cache",
+    "DataProcessor",
+    "StreamingDataset",
+    "DataTransformRecipe",
+    "DataChunkRecipe",
+    "TokensLoader",
+]
diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py
index e5966ba2d8c5e..7327c0ae4f213 100644
--- a/src/lightning/data/streaming/cache.py
+++ b/src/lightning/data/streaming/cache.py
@@ -94,6 +94,10 @@ def __init__(
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()
 
+    @property
+    def rank(self) -> int:
+        return self._reader.rank
+
     @property
     def filled(self) -> bool:
         """Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ def filled(self) -> bool:
         self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
         return self._is_done
 
+    @property
+    def checkpoint_dir(self) -> str:
+        checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir, exist_ok=True)
+        return checkpoint_dir
+
+    @property
+    def checkpoint_rank_dir(self) -> str:
+        checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
+        if not os.path.exists(checkpoint_rank_dir):
+            os.makedirs(checkpoint_rank_dir, exist_ok=True)
+        return checkpoint_rank_dir
+
     def __setitem__(self, index: int, data: Any) -> None:
         """Store an item in the writer."""
         self._writer[index] = data
diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py
index 65239b14f974a..759c42e07576d 100644
--- a/src/lightning/data/streaming/constants.py
+++ b/src/lightning/data/streaming/constants.py
@@ -51,3 +51,5 @@
     18: torch.long,
     19: torch.bool,
 }
+
+_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
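
Note: together, `checkpoint_dir`, `checkpoint_rank_dir` and `_TIME_FORMAT` define the on-disk layout `<cache_dir>/checkpoints/<rank>/checkpoint-<timestamp>.json`. A minimal sketch of the filename round-trip, with an illustrative cache path:

    import os
    from datetime import datetime

    _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

    # Hypothetical cache root, used for illustration only.
    cache_dir = "/tmp/my_cache"
    rank = 0

    # Mirrors Cache.checkpoint_rank_dir: <cache_dir>/checkpoints/<rank>/
    checkpoint_rank_dir = os.path.join(cache_dir, "checkpoints", str(rank))
    os.makedirs(checkpoint_rank_dir, exist_ok=True)

    now = datetime.now().strftime(_TIME_FORMAT)
    checkpoint_path = os.path.join(checkpoint_rank_dir, f"checkpoint-{now}.json")

    # Parsing back, as _string_to_datetime does, lets state_dict() pick the latest file.
    stamp = os.path.basename(checkpoint_path).split("checkpoint-")[1].split(".json")[0]
    assert isinstance(datetime.strptime(stamp, _TIME_FORMAT), datetime)
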
diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py
index ab92e55f880c9..cf93f568e3f21 100644
--- a/src/lightning/data/streaming/dataset.py
+++ b/src/lightning/data/streaming/dataset.py
@@ -12,20 +12,34 @@
 # limitations under the License.
 
 import hashlib
+import json
 import os
+import shutil
+import sys
+import tempfile
+from copy import deepcopy
 from dataclasses import dataclass
+from datetime import datetime
+from time import time
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
+import torch
 from torch.utils.data import IterableDataset
 
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import (
+    _DEFAULT_CACHE_DIR,
+    _INDEX_FILENAME,
+    _LIGHTNING_CLOUD_LATEST,
+    _TIME_FORMAT,
+)
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
+from lightning.fabric.utilities.distributed import group as _group
 
 if _LIGHTNING_CLOUD_LATEST:
     from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ def __init__(
         drop_last: bool = False,
         seed: int = 42,
         serializers: Optional[Dict[str, Serializer]] = None,
+        checkpoint_interval: int = 60 * 5,
     ) -> None:
         """The streaming dataset can be used once your data has been optimized using the DatasetOptimizer class.
@@ -53,6 +68,7 @@ def __init__(
                 all processes/workers return the same amount of data.
             seed: Random seed for shuffling.
             serializers: The serializers used to serialize and deserialize the chunks.
+            checkpoint_interval: Interval in seconds at which each worker stores its own progress.
 
         """
         super().__init__()
@@ -77,6 +93,7 @@ def __init__(
         self.worker_intervals: List[List[int]] = []
         self.current_indexes: List[int] = []
         self.chunk_index = 0
+        self.global_index = 0
         self.index = 0
         self.has_triggered_download = False
         self.min_items_per_replica: Optional[int] = None
@@ -84,6 +101,8 @@ def __init__(
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
+        self.checkpoint_interval = checkpoint_interval
+        self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None
 
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         return cache
 
     def _create_shuffler(self, cache: Cache) -> Shuffle:
-        return (
-            FullShuffle(cache, self.seed, self.drop_last)
-            if self.shuffle
-            else NoShuffle(cache, self.seed, self.drop_last)
-        )
+        seed = self.seed
+        if self._state_dict is not None:
+            seed = self._state_dict[str(cache.rank)]["seed"]
+        return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)
 
     def __len__(self) -> int:
         if self.shuffler is None:
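
Note: each checkpoint file holds one worker's progress as plain JSON, and `state_dict()` keys the payloads by `str(rank)`. An illustrative payload, with invented values:

    # Illustrative shape of one worker's checkpoint payload; the values are
    # made up for this example and match the fields written by checkpoint().
    example_state = {
        "0": {
            "rank": 0,
            "current_epoch": 1,
            "input_dir_path": "/cache/chunks",  # hypothetical path
            "input_dir_url": None,
            "item_loader": {"block_size": 20},
            "drop_last": False,
            "seed": 42,
            "checkpoint_interval": 300,
            "chunk_index": 2,    # next chunk to read
            "global_index": 4,   # items yielded so far
            "index": 0,          # position within the current chunk
            "world_size": 1,
            "num_workers": 1,
            "shuffle": True,
        }
    }
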
"StreamingDataset": self.worker_chunks.append(chunk_index) self.worker_intervals.append(chunk_interval) - self.current_indexes = [] - self.chunk_index = 0 - self.index = 0 + # Handle restart + if self._state_dict: + state = self._state_dict[str(self.cache.rank)] + + # re-generate indexes + interval = self.worker_intervals[self.chunk_index] + current_indexes = np.arange(interval[0], interval[1]) + current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index) + self.current_indexes = current_indexes[state["index"] :] + + # Bump the chunk_index + self.chunk_index += 1 + else: + self.current_indexes = [] + self.chunk_index = 0 + self.global_index = 0 + self.index = 0 + self.has_triggered_download = False + self.last_time = time() return self @@ -159,7 +204,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any: def __next__(self) -> Any: # Prevent to create more batch on a given process - if self.index >= len(self): + if self.global_index >= len(self): self.current_epoch += 1 raise StopIteration @@ -169,14 +214,19 @@ def __next__(self) -> Any: self.current_epoch += 1 raise StopIteration + # reset index + self.index = 0 + + # Checkpoint when reaching a new chunk + self.checkpoint(self.chunk_index) + interval = self.worker_intervals[self.chunk_index] current_indexes = np.arange(interval[0], interval[1]) assert self.shuffler is not None - self.current_indexes = self.shuffler(current_indexes) - self.chunk_index += 1 + self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index) - last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1 + self.chunk_index += 1 # Get the first index index = self.current_indexes.pop(0) @@ -188,15 +238,165 @@ def __next__(self) -> Any: chunk_index=self.worker_chunks[self.chunk_index - 1], # We provide the chunks indexes only one the first chunk_indexes=None if self.has_triggered_download else self.worker_chunks, - last_index=last_index, + last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1, ) ) self.has_triggered_download = True + self.global_index += 1 self.index += 1 + # Checkpoint based on time + if (self.last_time - time()) > self.checkpoint_interval: + self.checkpoint(self.chunk_index - 1) + return data + def checkpoint(self, chunk_index: int) -> None: + # Checkpointing isn't supported for windows + if sys.platform == "win32": + return + + assert self.cache + assert self.worker_env + + with tempfile.TemporaryDirectory() as tmpdir: + tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json") + with open(tmp_checkpoint_path, "w") as f: + # 1. Write the state to a tempfile + json.dump( + { + "rank": self.cache._reader.rank, + "current_epoch": self.current_epoch, + "input_dir_path": self.input_dir.path, + "input_dir_url": self.input_dir.url, + "item_loader": self.item_loader.state_dict() if self.item_loader else None, + "drop_last": self.drop_last, + "seed": self.seed, + "checkpoint_interval": self.checkpoint_interval, + "chunk_index": chunk_index, + "global_index": self.global_index, + "index": self.index, + "world_size": self.distributed_env.world_size, + "num_workers": self.worker_env.world_size, + "shuffle": self.shuffle, + }, + f, + ) + + # 3. Move the file to avoid corrupted read from the main thread. + now = datetime.now().strftime(_TIME_FORMAT) + checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json") + + # 4. 
+
+    def checkpoint(self, chunk_index: int) -> None:
+        # Checkpointing isn't supported on Windows
+        if sys.platform == "win32":
+            return
+
+        assert self.cache
+        assert self.worker_env
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
+            with open(tmp_checkpoint_path, "w") as f:
+                # 1. Write the state to a tempfile
+                json.dump(
+                    {
+                        "rank": self.cache._reader.rank,
+                        "current_epoch": self.current_epoch,
+                        "input_dir_path": self.input_dir.path,
+                        "input_dir_url": self.input_dir.url,
+                        "item_loader": self.item_loader.state_dict() if self.item_loader else None,
+                        "drop_last": self.drop_last,
+                        "seed": self.seed,
+                        "checkpoint_interval": self.checkpoint_interval,
+                        "chunk_index": chunk_index,
+                        "global_index": self.global_index,
+                        "index": self.index,
+                        "world_size": self.distributed_env.world_size,
+                        "num_workers": self.worker_env.world_size,
+                        "shuffle": self.shuffle,
+                    },
+                    f,
+                )
+
+            # 2. Generate a timestamped path within this rank's checkpoint directory
+            now = datetime.now().strftime(_TIME_FORMAT)
+            checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
+
+            # 3. Move the file to its target position, so the main thread never reads a partially written checkpoint
+            shutil.move(tmp_checkpoint_path, checkpoint_path)
+
+        self.last_time = time()
+
+    def state_dict(self) -> Dict[str, Any]:
+        if self.cache is None:
+            self.worker_env = _WorkerEnv.detect()
+            self.cache = self._create_cache(worker_env=self.worker_env)
+
+        state_dict: Dict[str, Any] = {}
+        worker_env = _WorkerEnv.detect()
+        if worker_env.world_size == 1:
+            # 1. Check whether the checkpoint_dir exists
+            if not os.path.exists(self.cache.checkpoint_dir):
+                return state_dict
+
+            # 2. Iterate through the workers and read their latest checkpoint
+            for worker_idx in os.listdir(self.cache.checkpoint_dir):
+                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
+                checkpoints = sorted(checkpoints, key=_string_to_datetime)
+
+                # Load the latest checkpoint for this worker
+                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
+                with open(checkpoint_path) as f:
+                    state_dict[worker_idx] = json.load(f)
+
+            _state_dict = deepcopy(state_dict)
+
+            if self.distributed_env.world_size > 1:
+                # TODO: Move this to fabric.
+                num_devices = torch.cuda.device_count() or 1
+                node_ranks = []
+                for index in range(self.distributed_env.world_size):
+                    node_rank = index // num_devices
+                    if node_rank in node_ranks:
+                        continue
+                    obj = [_state_dict]
+                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
+                    state = obj[0]
+                    state_dict.update(**state)
+                    node_ranks.append(node_rank)
+        else:
+            raise NotImplementedError("`state_dict` should be called on the main thread.")
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if state_dict:
+            # the state is restored within the workers
+            self._state_dict = state_dict
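
Note: the intended save/resume flow, distilled from `test_resumable_dataset_single_worker` later in this diff; the directory and loader arguments are illustrative:

    from torch.utils.data import DataLoader
    from lightning.data.streaming.dataset import StreamingDataset
    from lightning.data.streaming.item_loader import TokensLoader

    # Assumes an optimized dataset already exists under /data/chunks (illustrative path).
    dataset = StreamingDataset(input_dir="/data/chunks", item_loader=TokensLoader(block_size=20), shuffle=True)
    dataloader = DataLoader(dataset, num_workers=1, batch_size=2)

    it = iter(dataloader)
    next(it)

    # Capture the workers' progress from the main process ...
    state = dataset.state_dict()

    # ... and hand it to a fresh dataset to resume from the same position.
    resumed = StreamingDataset(input_dir="/data/chunks", item_loader=TokensLoader(block_size=20), shuffle=True)
    resumed.load_state_dict(state)
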
+
+    def _validate_state_dict(self) -> None:
+        assert self._state_dict
+        assert self.worker_env
+        assert self.cache
+
+        env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)
+
+        if env.num_shards != len(self._state_dict):
+            raise ValueError(
+                "The provided `state` size doesn't match the number of workers times the world size. "
+                f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
+            )
+
+        state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]
+
+        if state["shuffle"] != self.shuffle:
+            raise ValueError(
+                "The provided `shuffle` state doesn't match the current one. "
+                f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
+            )
+
+        if state["num_workers"] != self.worker_env.world_size:
+            raise ValueError(
+                "The provided `num_workers` state doesn't match the current one. "
+                f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
+            )
+
+        if state["input_dir_path"] != self.input_dir.path:
+            raise ValueError(
+                "The provided `input_dir` path state doesn't match the current one. "
+                f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
+            )
+
+        if state["input_dir_url"] != self.input_dir.url:
+            raise ValueError(
+                "The provided `input_dir` URL state doesn't match the current one. "
+                f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
+            )
+
+        if state["seed"] != self.seed:
+            raise ValueError(
+                "The provided `seed` state doesn't match the current one. "
+                f"Found `{self.seed}` instead of `{state['seed']}`."
+            )
+
+        if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
+            raise ValueError(
+                "The provided `item_loader` state doesn't match the current one. "
+                f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
+            )
+
+        if state["drop_last"] != self.drop_last:
+            raise ValueError(
+                "The provided `drop_last` state doesn't match the current one. "
+                f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
+            )
+
 
 def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     hash_object = hashlib.md5(input_dir.encode())
@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     return cache_dir
 
 
+def _string_to_datetime(item: str) -> datetime:
+    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
diff --git a/src/lightning/data/streaming/item_loader.py b/src/lightning/data/streaming/item_loader.py
index 1b028369ab388..1ca6082bd771c 100644
--- a/src/lightning/data/streaming/item_loader.py
+++ b/src/lightning/data/streaming/item_loader.py
@@ -37,6 +37,9 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
         self._chunks = chunks
         self._serializers = serializers
 
+    def state_dict(self) -> Dict:
+        return {}
+
     @abstractmethod
     def generate_intervals(self) -> List[Tuple[int, int]]:
         """Returns a list of tuples describing the index intervals of the chunks."""
@@ -115,6 +118,11 @@ def __init__(self, block_size: int):
         self._dtype: Optional[torch.dtype] = None
         self._chunk_filepaths: Dict[str, bool] = {}
 
+    def state_dict(self) -> Dict:
+        return {
+            "block_size": self._block_size,
+        }
+
     def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) -> None:
         super().setup(config, chunks, serializers)
         self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
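
Note: the base `state_dict` intentionally returns an empty dict, so only loaders that carry configuration, such as `TokensLoader`, participate in `_validate_state_dict`. A minimal sketch of the round-trip:

    from lightning.data.streaming.item_loader import TokensLoader

    # TokensLoader round-trips its only configuration, the block size; restoring
    # a run with a different block size fails _validate_state_dict.
    saved = TokensLoader(block_size=20).state_dict()
    assert saved == {"block_size": 20}
    assert TokensLoader(block_size=40).state_dict() != saved  # would raise on restore
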
diff --git a/src/lightning/data/streaming/shuffle.py b/src/lightning/data/streaming/shuffle.py
index d389cc3e66bfc..592b25106dddf 100644
--- a/src/lightning/data/streaming/shuffle.py
+++ b/src/lightning/data/streaming/shuffle.py
@@ -28,7 +28,6 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):
         self.cache = cache
         self.seed = seed
         self.drop_last = drop_last
-        self.random_state = None
 
     @lru_cache(maxsize=10)
     def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
@@ -48,7 +47,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
         pass
 
     @abstractmethod
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
         pass
 
@@ -68,7 +67,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
         return array.tolist()
 
 
@@ -92,14 +91,12 @@ class FullShuffle(Shuffle):
 
     @lru_cache(maxsize=10)
     def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
-        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
-
         # 1. Get the intervals
         chunk_intervals = self.cache.get_chunk_intervals()
 
         # 2. Shuffle them
         indexes = range(len(chunk_intervals))
-        shuffled_indexes = self.random_state.permutation(indexes)
+        shuffled_indexes = np.random.RandomState(seed=self.seed + current_epoch).permutation(indexes)
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
 
         # 3. Compute the item budget of each rank
@@ -147,6 +144,5 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
-        assert self.random_state
-        return self.random_state.permutation(array).tolist()
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
+        return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index 610250cb26368..efa886a9c41c4 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -161,7 +161,7 @@ def fn(*_, **__):
 
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 @mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
-def test_download_data_target(tmpdir):
+def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
 
@@ -194,6 +194,8 @@ def fn(*_, **__):
 
     assert os.listdir(cache_dir) == ["a.txt"]
 
+    wait_for_disk_usage_higher_than_threshold_mock.assert_called()
+
 
 def test_wait_for_disk_usage_higher_than_threshold():
     disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
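
Note: replacing the stored `random_state` with a `RandomState` re-seeded on every call is what makes shuffling resumable: the permutation for any (seed, epoch, chunk) triple can be recomputed after a restart instead of being carried as mutable state. A minimal sketch of the invariant, with arbitrary values:

    import numpy as np

    seed, current_epoch, chunk_index = 42, 1, 3
    array = np.arange(20, 40)  # arbitrary item indexes of one chunk

    def permute(array: np.ndarray) -> list:
        # Mirrors FullShuffle.__call__: a fresh RandomState per call, so no
        # shuffling state needs to be checkpointed.
        return np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array).tolist()

    # Recomputing after a restart yields the identical order.
    assert permute(array) == permute(array)
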
diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py
index 351046c8d8bc0..1a80b351a0523 100644
--- a/tests/tests_data/streaming/test_dataset.py
+++ b/tests/tests_data/streaming/test_dataset.py
@@ -11,8 +11,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 import sys
+from datetime import datetime
+from time import sleep
 from unittest import mock
 
 import numpy as np
@@ -20,6 +23,7 @@
 import torch
 from lightning import seed_everything
 from lightning.data.streaming import Cache, functions
+from lightning.data.streaming.constants import _TIME_FORMAT
 from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
 from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
@@ -160,7 +164,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 548
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
+    assert process_1_1[:10] == [788, 781, 785, 780, 787, 782, 789, 784, 783, 786]
     assert len(process_1_1) == 548
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -171,7 +175,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 548 + int(not drop_last)
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [939, 938, 252, 259, 257, 255, 258, 253, 250, 251]
+    assert process_2_1[:10] == [939, 938, 253, 259, 256, 258, 252, 255, 251, 257]
     assert len(process_2_1) == 548 + int(not drop_last)
 
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
@@ -200,7 +204,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 611
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
+    assert process_1_1[:10] == [188, 181, 185, 180, 187, 182, 189, 184, 183, 186]
     assert len(process_1_1) == 611
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -211,9 +215,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 611
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [813, 815, 816, 812, 818, 811, 817, 814, 819, 277]
+    assert process_2_1[:10] == [818, 812, 816, 811, 819, 813, 815, 814, 817, 273]
     assert len(process_2_1) == 611
-
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
@@ -527,3 +530,209 @@ def test_s3_streaming_dataset():
     dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
     assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
     assert dataset.input_dir.path is None
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows and macOS")
+def test_resumable_dataset_single_worker(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+
+    dataset.current_epoch = 1
+
+    assert dataset.state_dict() == {}
+
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+
+    dataloader_iter = iter(dataloader)
+
+    _ = next(dataloader_iter)
+    state_dict_0 = dataset.state_dict()
+
+    sleep(0.1)
+
+    assert state_dict_0["0"]["chunk_index"] == 0
+    assert state_dict_0["0"]["index"] == 0
+
+    checkpoint_dir = os.path.join(tmpdir, "checkpoints")
+    assert os.listdir(checkpoint_dir) == ["0"]
+
+    _ = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_1 = dataset.state_dict()
+    assert state_dict_1["0"]["chunk_index"] == 2
+    assert state_dict_1["0"]["index"] == 0
+
+    batch_2 = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_2 = dataset.state_dict()
+    assert state_dict_2["0"]["chunk_index"] == 3
+    assert state_dict_2["0"]["index"] == 0
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+    dataset.load_state_dict(state_dict_1)
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+
+    dataloader_iter = iter(dataloader)
+    batch_0_restart = next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict_2 = dataset.state_dict()
+    assert state_dict_2["0"]["chunk_index"] == 3
+    assert state_dict_2["0"]["index"] == 0
+
+    assert torch.equal(batch_2, batch_0_restart)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows and macOS")
+def test_dataset_valid_state(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+    dataloader_iter = iter(dataloader)
+    next(dataloader_iter)
+
+    sleep(0.1)
+
+    state_dict = dataset.state_dict()
+
+    dataset.load_state_dict(state_dict)
+    dataset._validate_state_dict()
+
+    state_dict["0"]["drop_last"] = True
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `drop_last` state doesn't match the current one. Found `False` instead of `True`.",
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["item_loader"] = {}
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `item_loader` state doesn't match the current one. Found `{'block_size': 20}` instead of `{}`.",  # noqa: E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["seed"] = 12
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `seed` state doesn't match the current one. Found `42` instead of `12`.",
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["input_dir_url"] = "toto"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.",
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["input_dir_path"] = "toto"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.",  # noqa: E501
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["num_workers"] = "8"
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `num_workers` state doesn't match the current one. Found `1` instead of `8`.",
+    ):
+        dataset._validate_state_dict()
+
+    state_dict["0"]["shuffle"] = True
+    dataset.load_state_dict(state_dict)
+    with pytest.raises(
+        ValueError,
+        match="The provided `shuffle` state doesn't match the current one. Found `False` instead of `True`.",
+    ):
+        dataset._validate_state_dict()
+
+
+def test_resumable_dataset_distributed_state_dict(tmpdir):
+    seed_everything(42)
+
+    block_size = 20
+    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+
+    counter = 0
+    for i in range(100):
+        text_ids = torch.arange(counter, counter + 20).to(torch.int)
+        cache[i] = text_ids
+        counter += 20
+
+    cache.done()
+    cache.merge()
+
+    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
+    dataset.distributed_env = _DistributedEnv(world_size=16, global_rank=0)
+
+    # Used only to create the cache
+    iter(dataset)
+    os.makedirs(dataset.cache.checkpoint_dir, exist_ok=True)
+
+    # Simulate checkpoints written by the four local workers
+    for i in range(4):
+        now = datetime.now().strftime(_TIME_FORMAT)
+        checkpoint_rank_dir = os.path.join(dataset.cache.checkpoint_dir, str(i))
+        os.makedirs(checkpoint_rank_dir, exist_ok=True)
+        checkpoint_path = os.path.join(checkpoint_rank_dir, f"checkpoint-{now}.json")
+        with open(checkpoint_path, "w") as f:
+            json.dump({}, f)
+
+    torch_mock = mock.MagicMock()
+    torch_mock.cuda.device_count.return_value = 4
+
+    # Each of the four simulated nodes contributes the states of its four ranks
+    state_list = [{} for _ in range(4)]
+    for i in range(16):
+        state_list[i // 4].update({str(i): {}})
+
+    def broadcast_object_list(obj, src, **kwargs):
+        assert src in [0, 4, 8, 12]
+        obj[0] = state_list.pop(0)
+
+    torch_mock.distributed.broadcast_object_list = broadcast_object_list
+
+    with mock.patch("lightning.data.streaming.dataset.torch", torch_mock):
+        state_dict = dataset.state_dict()
+
+    assert len(state_dict) == 16
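
Closing note: end to end, the new API is meant to be driven from a training loop. A hedged sketch under assumed names: the loop body, the save cadence, and "dataset_state.json" are illustrative placeholders, not part of this PR:

    import json

    def train_with_resumption(dataset, dataloader, resume_path=None):
        # Restore the workers' progress before iteration starts.
        if resume_path is not None:
            with open(resume_path) as f:
                dataset.load_state_dict(json.load(f))

        for step, _batch in enumerate(dataloader):
            # ... training step elided ...
            if step % 100 == 0:
                # The aggregated per-worker state is plain JSON, so it can be
                # persisted alongside the model checkpoint.
                with open("dataset_state.json", "w") as f:
                    json.dump(dataset.state_dict(), f)
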