Add fault tolerance for the StreamingDataset 1/n (#19049)
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 1073276)
tchaton authored and lantiga committed Dec 20, 2023
1 parent 9e77488 commit d48910a
Showing 8 changed files with 476 additions and 30 deletions.
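
As a rough sketch of the resume flow this commit introduces, the snippet below builds a dataset, consumes part of an epoch, captures the per-worker progress with state_dict, and restores it on a fresh instance with load_state_dict. The constructor arguments and the S3 path are hypothetical placeholders, not taken from this commit.

    from lightning.data.streaming import StreamingDataset

    # Hypothetical input directory and arguments, for illustration only.
    dataset = StreamingDataset(
        input_dir="s3://my-bucket/optimized-data",
        shuffle=True,
        checkpoint_interval=60 * 5,  # seconds between per-worker progress snapshots
    )

    # Consume part of an epoch (directly or through a DataLoader).
    for i, _item in enumerate(dataset):
        if i == 100:
            break

    # Collect the latest per-worker checkpoints written under <cache_dir>/checkpoints/<rank>/ ...
    state = dataset.state_dict()

    # ... and resume a fresh instance from that state; the indexes are restored inside __iter__.
    resumed = StreamingDataset(
        input_dir="s3://my-bucket/optimized-data",
        shuffle=True,
        checkpoint_interval=60 * 5,
    )
    resumed.load_state_dict(state)
    for _item in resumed:
        ...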
9 changes: 8 additions & 1 deletion src/lightning/data/streaming/__init__.py
@@ -16,4 +16,11 @@
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

__all__ = ["Cache", "DataProcessor", "StreamingDataset", "DataTransformRecipe", "DataChunkRecipe", "TokensLoader"]
__all__ = [
"Cache",
"DataProcessor",
"StreamingDataset",
"DataTransformRecipe",
"DataChunkRecipe",
"TokensLoader",
]
18 changes: 18 additions & 0 deletions src/lightning/data/streaming/cache.py
@@ -94,6 +94,10 @@ def __init__(
self._is_done = False
self._distributed_env = _DistributedEnv.detect()

@property
def rank(self) -> int:
return self._reader.rank

@property
def filled(self) -> bool:
"""Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ def filled(self) -> bool:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done

@property
def checkpoint_dir(self) -> str:
checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
return checkpoint_dir

@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
if not os.path.exists(checkpoint_rank_dir):
os.makedirs(checkpoint_rank_dir, exist_ok=True)
return checkpoint_rank_dir

def __setitem__(self, index: int, data: Any) -> None:
"""Store an item in the writer."""
self._writer[index] = data
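
The two new properties above lazily create a per-rank checkpoint folder inside the cache directory. A minimal sketch of how they resolve on disk, assuming a hypothetical cache directory of /tmp/chunks and worker rank 1:

    import os

    cache_dir = "/tmp/chunks"  # hypothetical value of Cache._cache_dir
    rank = 1                   # hypothetical value of Cache.rank

    checkpoint_dir = os.path.join(cache_dir, "checkpoints")        # /tmp/chunks/checkpoints
    checkpoint_rank_dir = os.path.join(checkpoint_dir, str(rank))  # /tmp/chunks/checkpoints/1

    # Both properties create the folder on first access.
    os.makedirs(checkpoint_rank_dir, exist_ok=True)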
2 changes: 2 additions & 0 deletions src/lightning/data/streaming/constants.py
@@ -51,3 +51,5 @@
18: torch.long,
19: torch.bool,
}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
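
_TIME_FORMAT is used further down in this diff to timestamp checkpoint files and to sort them when resuming. A small round-trip sketch (the filename is illustrative):

    from datetime import datetime

    _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

    now = datetime.now().strftime(_TIME_FORMAT)  # e.g. "2023-12-20_14-05-33.123456Z"
    filename = f"checkpoint-{now}.json"

    # Parse it back, mirroring the `_string_to_datetime` helper added in dataset.py below.
    parsed = datetime.strptime(filename.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)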
232 changes: 218 additions & 14 deletions src/lightning/data/streaming/dataset.py
@@ -12,20 +12,34 @@
# limitations under the License.

import hashlib
import json
import os
import shutil
import sys
import tempfile
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
from lightning.data.streaming.constants import (
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TIME_FORMAT,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.fabric.utilities.distributed import group as _group

if _LIGHTNING_CLOUD_LATEST:
from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ def __init__(
drop_last: bool = False,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: int = 60 * 5,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -53,6 +68,7 @@
all processes/workers return the same amount of data.
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.
"""
super().__init__()
@@ -77,13 +93,16 @@ def __init__(
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.current_epoch = 0
self.random_state = None
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self.checkpoint_interval = checkpoint_interval
self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
return cache

def _create_shuffler(self, cache: Cache) -> Shuffle:
return (
FullShuffle(cache, self.seed, self.drop_last)
if self.shuffle
else NoShuffle(cache, self.seed, self.drop_last)
)
seed = self.seed
if self._state_dict is not None:
seed = self._state_dict[str(cache.rank)]["seed"]
return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)

def __len__(self) -> int:
if self.shuffler is None:
@@ -126,6 +144,17 @@ def __iter__(self) -> "StreamingDataset":
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)

# Handle restart
if self._state_dict:
self._validate_state_dict()
state = self._state_dict[str(self.cache.rank)]

# reload indexes
self.chunk_index = state["chunk_index"]
self.global_index = state["global_index"]
self.index = state["index"]
self.current_epoch = state["current_epoch"]

chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
self.distributed_env, self.current_epoch
)
@@ -141,10 +170,26 @@
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)

self.current_indexes = []
self.chunk_index = 0
self.index = 0
# Handle restart
if self._state_dict:
state = self._state_dict[str(self.cache.rank)]

# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
self.chunk_index += 1
else:
self.current_indexes = []
self.chunk_index = 0
self.global_index = 0
self.index = 0

self.has_triggered_download = False
self.last_time = time()

return self
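
When a state dict has been loaded, the restart branch above rebuilds the index list for the chunk that was in flight and skips the items already consumed. A small numeric sketch of that slicing, with a made-up interval and saved index and the shuffler left out:

    import numpy as np

    interval = [100, 110]  # hypothetical worker interval for the in-flight chunk
    saved_index = 4        # items already consumed in this chunk before the restart

    current_indexes = np.arange(interval[0], interval[1])  # [100, 101, ..., 109]
    current_indexes = current_indexes[saved_index:]        # resume from [104, ..., 109]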

@@ -159,7 +204,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Prevent creating more batches than expected on a given process
if self.index >= len(self):
if self.global_index >= len(self):
self.current_epoch += 1
raise StopIteration

@@ -169,14 +214,19 @@ def __next__(self) -> Any:
self.current_epoch += 1
raise StopIteration

# reset index
self.index = 0

# Checkpoint when reaching a new chunk
self.checkpoint(self.chunk_index)

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])

assert self.shuffler is not None
self.current_indexes = self.shuffler(current_indexes)
self.chunk_index += 1
self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)

last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
self.chunk_index += 1

# Get the first index
index = self.current_indexes.pop(0)
@@ -188,15 +238,165 @@ def __next__(self) -> Any:
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunk indexes only on the first iteration
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
last_index=last_index,
last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
)
)

self.has_triggered_download = True
self.global_index += 1
self.index += 1

# Checkpoint based on time
if (time() - self.last_time) > self.checkpoint_interval:
self.checkpoint(self.chunk_index - 1)

return data

def checkpoint(self, chunk_index: int) -> None:
# Checkpointing isn't supported on Windows
if sys.platform == "win32":
return

assert self.cache
assert self.worker_env

with tempfile.TemporaryDirectory() as tmpdir:
tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
with open(tmp_checkpoint_path, "w") as f:
# 1. Write the state to a tempfile
json.dump(
{
"rank": self.cache._reader.rank,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
"checkpoint_interval": self.checkpoint_interval,
"chunk_index": chunk_index,
"global_index": self.global_index,
"index": self.index,
"world_size": self.distributed_env.world_size,
"num_workers": self.worker_env.world_size,
"shuffle": self.shuffle,
},
f,
)

# 2. Build the checkpoint filename from the current timestamp
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")

# 3. Move the file into place, so the main thread never reads a partially written checkpoint
shutil.move(tmp_checkpoint_path, checkpoint_path)

self.last_time = time()

def state_dict(self) -> Dict[str, Any]:
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)

state_dict: Dict[str, Any] = {}
worker_env = _WorkerEnv.detect()
if worker_env.world_size == 1:
# 1. Check whether the checkpoint_dir exists
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

# 2. Iterate through the workers and read the latest checkpoint
for worker_idx in os.listdir(self.cache.checkpoint_dir):
checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
checkpoints = sorted(checkpoints, key=_string_to_datetime)

# Load the latest checkpoint for this worker
checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
with open(checkpoint_path) as f:
state_dict[worker_idx] = json.load(f)

_state_dict = deepcopy(state_dict)

if self.distributed_env.world_size > 1:
# TODO: Move this to fabric.
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(self.distributed_env.world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [_state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict.update(**state)
node_ranks.append(node_rank)
else:
raise NotImplementedError("The `state_dict` should be called on the main thread.")
return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
# the state is restored within the workers
self._state_dict = state_dict

def _validate_state_dict(self) -> None:
assert self._state_dict
assert self.worker_env
assert self.cache

env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)

if env.num_shards != len(self._state_dict):
raise ValueError(
"The provided `state` size doesn't match the number workers world size. "
f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
)

state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]

if state["shuffle"] != self.shuffle:
raise ValueError(
"The provided `shuffle` state doesn't match the current one. "
f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
)

if state["num_workers"] != self.worker_env.world_size:
raise ValueError(
"The provided `num_workers` state doesn't match the current one. "
f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
)

if state["input_dir_path"] != self.input_dir.path:
raise ValueError(
"The provided `input_dir` path state doesn't match the current one. "
f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
)

if state["input_dir_url"] != self.input_dir.url:
raise ValueError(
"The provided `input_dir` URL state doesn't match the current one. "
f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
)

if state["seed"] != self.seed:
raise ValueError(
"The provided `seed` state doesn't match the current one. "
f"Found `{self.seed}` instead of `{state['seed']}`."
)

if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
raise ValueError(
"The provided `item_loader` state doesn't match the current one. "
f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
)

if state["drop_last"] != self.drop_last:
raise ValueError(
"The provided `drop_last` state doesn't match the current one. "
f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
)


def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
hash_object = hashlib.md5(input_dir.encode())
@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
return cache_dir


def _string_to_datetime(item: str) -> datetime:
return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


@dataclass
class RemoteDir:
"""Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
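
Tying the dataset.py changes together: each worker periodically writes a small JSON payload into its own checkpoint folder, and state_dict keeps only the most recent file per worker by parsing the timestamp embedded in the filename. A hedged sketch of such a payload and of the selection step; the field values and the latest_checkpoint helper are illustrative, not part of the commit:

    import json
    import os
    from datetime import datetime

    _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

    # Keys mirror what `StreamingDataset.checkpoint` writes; the values are made up.
    payload = {
        "rank": 0,
        "current_epoch": 0,
        "input_dir_path": "/data/optimized",
        "input_dir_url": None,
        "item_loader": None,
        "drop_last": False,
        "seed": 42,
        "checkpoint_interval": 300,
        "chunk_index": 3,
        "global_index": 1250,
        "index": 17,
        "world_size": 1,
        "num_workers": 1,
        "shuffle": True,
    }


    def latest_checkpoint(worker_dir: str) -> dict:
        """Return the most recent checkpoint payload found in a per-worker folder."""
        checkpoints = sorted(
            os.listdir(worker_dir),
            key=lambda name: datetime.strptime(name.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT),
        )
        with open(os.path.join(worker_dir, checkpoints[-1])) as f:
            return json.load(f)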
(Diffs for the remaining four changed files were not loaded.)
