Add fault tolerance Streaming Dataset 2/n #19052

Merged
merged 24 commits on Nov 23, 2023
update
thomas authored and thomas committed Nov 22, 2023

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
commit a89c529e3ab003ca1083c2d936f2c737465bdeba
8 changes: 5 additions & 3 deletions src/lightning/data/streaming/cache.py
@@ -13,7 +13,6 @@

import logging
import os
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -94,7 +93,10 @@ def __init__(
)
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
self._resume_id = uuid.uuid4()

@property
def rank(self) -> int:
return self._reader.rank

@property
def filled(self) -> bool:
@@ -106,7 +108,7 @@ def filled(self) -> bool:

@property
def resume_folder(self) -> str:
resume_folder = os.path.join(self._cache_dir, self._resume_id)
resume_folder = os.path.join(self._cache_dir, "checkpoints", str(self._reader.rank))
if not os.path.exists(resume_folder):
os.makedirs(resume_folder, exist_ok=True)
return resume_folder
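
For orientation, here is a minimal standalone sketch (the helper name and paths are illustrative, not part of the diff) of the folder layout the reworked resume_folder property produces: each reader rank gets its own directory under <cache_dir>/checkpoints/, created lazily on first access.

import os

def resume_folder_for(cache_dir: str, rank: int) -> str:
    # Mirrors the property above: one checkpoint folder per rank,
    # e.g. /cache/my_dataset/checkpoints/0 for rank 0.
    folder = os.path.join(cache_dir, "checkpoints", str(rank))
    os.makedirs(folder, exist_ok=True)
    return folder
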
117 changes: 93 additions & 24 deletions src/lightning/data/streaming/dataset.py
@@ -14,12 +14,14 @@
import hashlib
import json
import os
import shutil
import uuid
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

import numpy as np
from lightning_cloud.resolver import Dir as InputDir
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
@@ -40,13 +42,13 @@ class StreamingDataset(IterableDataset):

def __init__(
self,
input_dir: Union[str, InputDir],
input_dir: Union[str, "RemoteDir"],
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: bool = False,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_progress_interval=None,
checkpoint_interval=None,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -58,7 +60,7 @@ def __init__(
all processes/workers return the same amount of data.
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_progress_interval: Interval in seconds at which the workers are going to store their own progress.
checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.

"""
super().__init__()
@@ -83,6 +85,7 @@ def __init__(
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
@@ -91,7 +94,8 @@ def __init__(
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self.resume_id = uuid.uuid4()
self.checkpoint_progress_interval = checkpoint_progress_interval or 60 * 5
self.checkpoint_interval = checkpoint_interval or 60 * 5
self._state_dict: Optional[Dict] = None

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -117,11 +121,10 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
return cache

def _create_shuffler(self, cache: Cache) -> Shuffle:
return (
FullShuffle(cache, self.seed, self.drop_last)
if self.shuffle
else NoShuffle(cache, self.seed, self.drop_last)
)
seed = self.seed
if self._state_dict:
seed = self._state_dict[str(self.cache.rank)]["seed"]
return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)
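
The seed handling above can be read in isolation as follows; this is an illustrative sketch (the helper is hypothetical), showing that a restored state dict overrides the constructor seed so the shuffle order is reproduced on resume.

from typing import Dict, Optional

def pick_seed(default_seed: int, state_dict: Optional[Dict], rank: int) -> int:
    # Fresh run: no restored state, use the configured seed.
    if not state_dict:
        return default_seed
    # Resumed run: reuse the seed recorded in this rank's checkpoint.
    return state_dict[str(rank)]["seed"]

assert pick_seed(42, None, 0) == 42
assert pick_seed(42, {"0": {"seed": 7}}, 0) == 7
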

def __len__(self) -> int:
if self.shuffler is None:
@@ -149,11 +152,26 @@ def __iter__(self) -> "StreamingDataset":
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)

self.current_indexes = []
self.chunk_index = 0
self.index = 0
self.has_triggered_download = False
self.last_time = time()
# Handle restart
if self._state_dict:
state = self._state_dict[str(self.cache.rank)]
self.chunk_index = state["chunk_index"]
self.global_index = state["global_index"]
self.index = state["index"]
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes)
self.current_indexes = current_indexes[state["index"] :]
self.has_triggered_download = False
self.last_time = time()
self.chunk_index += 1
else:
self.current_indexes = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.last_time = time()

return self
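
The restart branch above boils down to: jump back to the saved chunk, regenerate that chunk's (seeded) index order, and skip the items already served. A standalone sketch under those assumptions, with the shuffler passed in as a plain callable:

import numpy as np

def resume_indexes(interval, already_served, shuffler=lambda idx: idx):
    # Rebuild the full index range of the saved chunk ...
    indexes = np.arange(interval[0], interval[1])
    # ... apply the same deterministic shuffle as before the restart ...
    indexes = shuffler(indexes)
    # ... and drop the items consumed in the previous run.
    return indexes[already_served:]

# e.g. resume_indexes((100, 140), 15) resumes at the 16th item of that chunk.
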

@@ -168,7 +186,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Stop once this process has yielded all of its assigned items
if self.index >= len(self):
if self.global_index >= len(self):
self.current_epoch += 1
raise StopIteration

@@ -178,6 +196,12 @@ def __next__(self) -> Any:
self.current_epoch += 1
raise StopIteration

# reset index
self.index = 0

# Checkpoint when reaching a new chunk
self.checkpoint()

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])

@@ -202,28 +226,73 @@ def __next__(self) -> Any:
)

self.has_triggered_download = True
self.global_index += 1
self.index += 1

if (self.last_time + self.checkpoint_progress_interval) > time():
self.store_progress()
# Checkpoint based on time
if (time() - self.last_time) > self.checkpoint_interval:
self.checkpoint()

return data

def store_progress(self) -> None:
pass
def checkpoint(self) -> None:
import tempfile

with tempfile.NamedTemporaryFile(mode="w+") as tmp:
json.dump(
{
"rank": self.cache._reader.rank,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"item_loader": self.item_loader.state_dict(),
"drop_last": self.drop_last,
"seed": self.seed,
"checkpoint_interval": self.checkpoint_interval,
"chunk_index": self.chunk_index,
"global_index": self.global_index,
"index": self.index,
},
tmp,
)

tmp.flush()

now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%fZ")
checkpoint_path = os.path.join(self.cache.resume_folder, f"checkpoint-{now}.json")
shutil.copyfile(tmp.name, checkpoint_path)

self.last_time = time()
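
What each checkpoint() call leaves on disk can be summarised with this rough sketch (the helper name is an assumption): the progress dict is written to a temporary file first and then copied into the per-rank resume folder under a timestamped name, so a reader never observes a half-written checkpoint.

import json
import os
import shutil
import tempfile
from datetime import datetime

def write_checkpoint(resume_folder: str, state: dict) -> str:
    with tempfile.NamedTemporaryFile(mode="w+") as tmp:
        json.dump(state, tmp)
        tmp.flush()
        now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%fZ")
        checkpoint_path = os.path.join(resume_folder, f"checkpoint-{now}.json")
        shutil.copyfile(tmp.name, checkpoint_path)
    return checkpoint_path
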

def state_dict(self) -> Dict[_DictKey, Any]:
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)

state_dict = {}
worker_env = _WorkerEnv.detect()
if worker_env.world_size == 1:
for worker_idx, worker_state_file in enumerate(os.listdir(self.cache.resume_folder)):
work_state_filepath = os.path.join(self.cache.resume_folder, worker_state_file)
with open(work_state_filepath) as f:
checkpoint_dir = os.path.join(self.cache._cache_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
return state_dict
for worker_idx in os.listdir(checkpoint_dir):
checkpoints = os.listdir(os.path.join(checkpoint_dir, str(worker_idx)))
checkpoints = sorted(
checkpoints,
key=lambda item: datetime.strptime(
item.split("checkpoint-")[1].split(".json")[0], "%Y-%m-%d_%H-%M-%S.%fZ"
),
)
checkpoint_path = os.path.join(checkpoint_dir, str(worker_idx), checkpoints[-1])
with open(checkpoint_path) as f:
state_dict[worker_idx] = json.load(f)
else:
raise NotImplementedError("The `state_dict` should be called on the main thread.")
return state_dict
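
The aggregation above keeps only the newest checkpoint per worker. A minimal sketch of that selection step in isolation (the function name is hypothetical; the directory is the per-rank folder created by resume_folder):

import os
from datetime import datetime

def latest_checkpoint(worker_dir: str) -> str:
    # Files are named checkpoint-<timestamp>.json, so sorting on the parsed
    # timestamp and taking the last entry yields the most recent checkpoint.
    names = sorted(
        os.listdir(worker_dir),
        key=lambda name: datetime.strptime(
            name.split("checkpoint-")[1].split(".json")[0], "%Y-%m-%d_%H-%M-%S.%fZ"
        ),
    )
    return os.path.join(worker_dir, names[-1])
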

def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
return self.cache.load_state_dict()
if state_dict:
self._state_dict = state_dict


def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
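
Taken together, the dataset changes enable the following resume flow. This is a hedged usage sketch, not code from the PR: the input directory is a placeholder and the constructor arguments simply follow the signature shown above.

from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import StreamingDataset

dataset = StreamingDataset(input_dir="/cache/optimized_dataset", checkpoint_interval=300)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2)

for i, batch in enumerate(dataloader):
    if i == 10:  # stop partway through the epoch
        break

# Collect the latest per-worker progress (keyed by worker rank).
progress = dataset.state_dict()

# Later: rebuild the dataset, restore the progress, and continue mid-epoch.
resumed = StreamingDataset(input_dir="/cache/optimized_dataset", checkpoint_interval=300)
resumed.load_state_dict(progress)
resumed_loader = DataLoader(resumed, num_workers=1, batch_size=2)
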
8 changes: 8 additions & 0 deletions src/lightning/data/streaming/item_loader.py
@@ -37,6 +37,9 @@ def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer])
self._chunks = chunks
self._serializers = serializers

def state_dict(self) -> Dict:
return {}

@abstractmethod
def generate_intervals(self) -> List[Tuple[int, int]]:
"""Returns a list of tuple describing the indexes intervals of the chunks."""
@@ -115,6 +118,11 @@ def __init__(self, block_size: int):
self._dtype: Optional[torch.dtype] = None
self._chunk_filepaths: Dict[str, bool] = {}

def state_dict(self) -> Dict:
return {
"block_size": self._block_size,
}

def setup(self, config: Dict, chunks: List, serializers: Dict[str, Serializer]) -> None:
super().setup(config, chunks, serializers)
self._dtype = _TORCH_DTYPES_MAPPING[int(config["data_format"][0].split(":")[1])]
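
As a brief illustration of the intent behind the new state_dict hooks (the resume side is an assumption, not part of this diff): the item loader exports just enough configuration, here the block size, for an equivalent loader to be rebuilt when restoring.

from lightning.data.streaming.item_loader import TokensLoader

loader = TokensLoader(block_size=20)
state = loader.state_dict()  # {"block_size": 20}

# On resume, an equivalent loader can be reconstructed from the stored config.
restored = TokensLoader(block_size=state["block_size"])
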
53 changes: 49 additions & 4 deletions tests/tests_data/streaming/test_dataset.py
@@ -13,6 +13,7 @@

import os
import sys
from time import sleep
from unittest import mock

import numpy as np
@@ -522,10 +523,17 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on Windows and macOS")
def test_s3_streaming_dataset():
dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
assert dataset.input_dir.path is None


def test_resumable_dataset(tmpdir):
seed_everything(42)

block_size = 10
block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))

counter = 0
@@ -541,12 +549,49 @@ def test_resumable_dataset(tmpdir):

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)

dataloader = DataLoader(dataset, num_workers=2, batch_size=2)
assert dataset.state_dict() == {}

dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)

batch_0 = next(dataloader_iter)
_ = next(dataloader_iter)
state_dict_0 = dataset.state_dict()

sleep(0.1)

assert state_dict_0["0"]["chunk_index"] == 0
assert state_dict_0["0"]["index"] == 0

checkpoint_dir = os.path.join(tmpdir, "checkpoints")
assert os.listdir(checkpoint_dir) == ["0"]
_ = next(dataloader_iter)

sleep(0.1)

batch_1 = next(dataloader_iter)
state_dict_1 = dataset.state_dict()
assert state_dict_1["0"]["chunk_index"] == 2
assert state_dict_1["0"]["index"] == 0

batch_2 = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3
assert state_dict_2["0"]["index"] == 0

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataset.load_state_dict(state_dict_1)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)
batch_0_restart = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3
assert state_dict_2["0"]["index"] == 0

assert torch.equal(batch_2, batch_0_restart)
You are viewing a condensed version of this merge commit; not every changed line is shown.