Add fault tolerance Streaming Dataset 2/n (#19052)
Co-authored-by: thomas <[email protected]>
tchaton and thomas authored Nov 23, 2023
1 parent bf54a1d commit a6da1e3
Showing 8 changed files with 111 additions and 84 deletions.
23 changes: 14 additions & 9 deletions src/lightning/data/streaming/cache.py
@@ -26,7 +26,7 @@
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.writer import BinaryWriter
from lightning.data.utilities.env import _DistributedEnv
from lightning.data.utilities.env import _DistributedEnv, _WorkerEnv
from lightning.data.utilities.format import _convert_bytes_to_int

logger = logging.Logger(__name__)
@@ -93,10 +93,15 @@ def __init__(
)
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
self._rank: Optional[int] = None

@property
def rank(self) -> int:
return self._reader.rank
"""Returns the rank of the Cache."""
if self._rank is None:
self._worker_env = _WorkerEnv.detect()
self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank
return self._rank
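
The composite rank above is the distributed global rank scaled by the number of dataloader workers per process, plus the local worker rank. A standalone sketch of the same arithmetic, with made-up values:

# Illustrative values only: 2 nodes x 2 GPUs, 4 dataloader workers per process.
global_rank = 3                                  # this process is GPU 1 on node 1
worker_world_size = 4                            # _WorkerEnv.world_size
worker_rank = 2                                  # this worker's index within its process
cache_rank = global_rank * worker_world_size + worker_rank
print(cache_rank)                                # 14, unique across all processes and workers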

@property
def filled(self) -> bool:
@@ -109,16 +114,16 @@ def filled(self) -> bool:
@property
def checkpoint_dir(self) -> str:
checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
return checkpoint_dir
return self._try_create(checkpoint_dir)

@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
if not os.path.exists(checkpoint_rank_dir):
os.makedirs(checkpoint_rank_dir, exist_ok=True)
return checkpoint_rank_dir
checkpoint_rank_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
return self._try_create(checkpoint_rank_dir)

def _try_create(self, path: str) -> str:
os.makedirs(path, exist_ok=True)
return path
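
With `_try_create`, both checkpoint directories are created lazily and idempotently, giving each rank its own folder under the cache. A hedged sketch of the resulting layout and of the helper's behaviour (paths are illustrative):

# Hypothetical on-disk layout once a few ranks have checkpointed:
#   <cache_dir>/checkpoints/0/checkpoint.json
#   <cache_dir>/checkpoints/1/checkpoint.json
import os
import tempfile

cache_dir = tempfile.mkdtemp()
rank_dir = os.path.join(cache_dir, "checkpoints", "0")
os.makedirs(rank_dir, exist_ok=True)   # same call as _try_create
os.makedirs(rank_dir, exist_ok=True)   # repeating it is a no-op thanks to exist_ok=True
print(os.path.isdir(rank_dir))         # True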

def __setitem__(self, index: int, data: Any) -> None:
"""Store an item in the writer."""
126 changes: 72 additions & 54 deletions src/lightning/data/streaming/dataset.py
@@ -17,15 +17,14 @@
import shutil
import sys
import tempfile
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset
from torch.utils.data import IterableDataset, get_worker_info

from lightning.data.streaming import Cache
from lightning.data.streaming.constants import (
@@ -56,7 +55,7 @@ def __init__(
drop_last: bool = False,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: int = 60 * 5,
checkpoint_interval: Optional[int] = None,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
@@ -93,15 +92,19 @@ def __init__(
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.num_chunks: Optional[int] = None
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.current_epoch = 0
self.current_epoch = 1
self.random_state = None
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self.checkpoint_interval = checkpoint_interval
if sys.platform == "win32":
if checkpoint_interval is not None:
raise ValueError("The argument `checkpoint_interval` isn't suported on Windows.")
self.checkpoint_interval = checkpoint_interval or 60
self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None
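
In practice, checkpointing now defaults to every 60 seconds and must stay disabled (left as None) on Windows. A hedged construction example; the `input_dir` argument and the path are assumptions about the public signature, which is not shown in this hunk:

from lightning.data.streaming.dataset import StreamingDataset

# Hypothetical usage; only `checkpoint_interval` is taken from this diff.
dataset = StreamingDataset(
    input_dir="s3://my-bucket/optimized-data",   # assumed argument name, illustrative path
    checkpoint_interval=30,                      # checkpoint about every 30s instead of the 60s default
)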

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
@@ -170,14 +173,16 @@ def __iter__(self) -> "StreamingDataset":
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)

self.num_chunks = len(self.worker_chunks)

# Handle restart
if self._state_dict:
state = self._state_dict[str(self.cache.rank)]

# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
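
The restart branch above rebuilds the shuffled index list for the interrupted chunk and drops everything already served before iteration resumes. A self-contained sketch of that slicing, with a made-up interval, shuffler, and saved position:

import numpy as np

interval = (100, 110)                            # hypothetical chunk covering samples 100..109
rng = np.random.RandomState(42)                  # stands in for the epoch/chunk-seeded shuffler
current_indexes = rng.permutation(np.arange(interval[0], interval[1]))
saved_index = 4                                  # state["index"]: items already consumed from this chunk
remaining = current_indexes[saved_index:]        # only the not-yet-seen samples are replayed
print(remaining)
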
@@ -210,21 +215,22 @@ def __next__(self) -> Any:

# Lazily re-populate the interval to reduce memory usage.
if len(self.current_indexes) == 0:
if self.chunk_index == len(self.worker_intervals):
if self.chunk_index == self.num_chunks:
self.current_epoch += 1
raise StopIteration

# reset index
self.index = 0

# Checkpoint when reaching a new chunk
self.checkpoint(self.chunk_index)
self._checkpoint(self.chunk_index)

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])

assert self.shuffler is not None
self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
assert self.num_chunks is not None
self.current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

self.chunk_index += 1

@@ -238,7 +244,7 @@ def __next__(self) -> Any:
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunk indexes only on the first item
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
is_last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
)
)

@@ -247,14 +253,16 @@ def __next__(self) -> Any:
self.index += 1

# Checkpoint based on time
if (self.last_time - time()) > self.checkpoint_interval:
self.checkpoint(self.chunk_index - 1)
if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
self._checkpoint(self.chunk_index - 1)

return data

def checkpoint(self, chunk_index: int) -> None:
# Checkpointing isn't supported for windows
if sys.platform == "win32":
def _checkpoint(self, chunk_index: int) -> None:
if self.checkpoint_interval is None:
return

if not _is_in_dataloader_worker():
return

assert self.cache
@@ -284,55 +292,29 @@ def checkpoint(self, chunk_index: int) -> None:
f,
)

# 3. Move the file to avoid corrupted read from the main thread.
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")

# 4. Move the file to its target position
shutil.move(tmp_checkpoint_path, checkpoint_path)
shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))

self.last_time = time()
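
The checkpoint file is written to a temporary location first and only then moved into the rank directory, so a concurrent reader never sees a half-written JSON. A generic sketch of that write-then-move pattern; the payload keys and target path are illustrative:

import json
import os
import shutil
import tempfile

def atomic_write_json(payload: dict, target_path: str) -> None:
    # Dump into a temporary file, then move it into place so readers never see a partial file.
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = os.path.join(tmpdir, "checkpoint.json")
        with open(tmp_path, "w") as f:
            json.dump(payload, f)
        shutil.move(tmp_path, target_path)

atomic_write_json({"chunk_index": 3, "index": 12}, os.path.join(tempfile.gettempdir(), "checkpoint.json"))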

def state_dict(self) -> Dict[str, Any]:
if _is_in_dataloader_worker():
raise RuntimeError("The method `state_dict` should only be called in the main process.")

if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)

state_dict: Dict[str, Any] = {}
worker_env = _WorkerEnv.detect()
if worker_env.world_size == 1:
# 1. Check whether the checkpoint_dir exists
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

# 2. Iterate through the workers and read the latest checkpoint
for worker_idx in os.listdir(self.cache.checkpoint_dir):
checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
checkpoints = sorted(checkpoints, key=_string_to_datetime)

# Load the latest checkpoint for this worker
checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
with open(checkpoint_path) as f:
state_dict[worker_idx] = json.load(f)

_state_dict = deepcopy(state_dict)

if self.distributed_env.world_size > 1:
# TODO: Move this to fabric.
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(self.distributed_env.world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [_state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict.update(**state)
node_ranks.append(node_rank)
else:
raise NotImplementedError("The `state_dict` should be called on the main thread.")

# 1. Check whether the checkpoint_dir exists
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)

if self.distributed_env.world_size > 1:
return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
return state_dict
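
End to end, a run can be resumed by capturing `state_dict()` in the main process and feeding it back through `load_state_dict()`. A hedged usage sketch; the `input_dir` argument and path are assumed, as above:

from torch.utils.data import DataLoader
from lightning.data.streaming.dataset import StreamingDataset

dataset = StreamingDataset(input_dir="s3://my-bucket/optimized-data")   # hypothetical location
loader = DataLoader(dataset, num_workers=2)

for step, batch in enumerate(loader):
    if step == 100:
        break

state = dataset.state_dict()            # main process only; workers write the checkpoint files

# Later, possibly in a new process:
resumed = StreamingDataset(input_dir="s3://my-bucket/optimized-data")
resumed.load_state_dict(state)          # iteration restarts from the checkpointed position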

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -413,6 +395,42 @@ def _string_to_datetime(item: str) -> datetime:
return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
state_dict: Dict[str, Any] = {}
if not os.path.exists(checkpoint_dir):
return state_dict
for worker_idx in os.listdir(checkpoint_dir):
checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
if not os.path.exists(checkpoint_filepath):
state_dict[worker_idx] = {}
else:
with open(checkpoint_filepath) as f:
state_dict[worker_idx] = json.load(f)
return state_dict


def _collect_distributed_state_dict(state_dict: Dict[str, Any], world_size: int) -> Dict[str, Any]:
state_dict_out: Dict[str, Any] = {}
# TODO: Move this to fabric to support all accelerators
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict_out.update(**state)
node_ranks.append(node_rank)
return state_dict_out
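
The aggregation above relies on `torch.distributed.broadcast_object_list`, which overwrites each rank's list entry with the object held by the source rank. A minimal single-process sketch of that primitive; the gloo backend, port, and payload are arbitrary choices for illustration:

import torch.distributed as dist

# world_size=1 keeps this runnable on one machine; in real use every rank executes the same call.
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

obj = [{"0": {"index": 12}}]            # illustrative per-rank state
dist.broadcast_object_list(obj, 0)      # every rank now holds rank 0's dictionary
print(obj[0])

dist.destroy_process_group()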


def _is_in_dataloader_worker() -> bool:
return get_worker_info() is not None
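
`get_worker_info()` returns None in the main process and a populated WorkerInfo inside a dataloader worker, which is exactly what the helper above checks. A quick probe, wrapped in a main guard so it also behaves with the spawn start method:

from torch.utils.data import DataLoader, Dataset, get_worker_info

class _Probe(Dataset):
    def __len__(self) -> int:
        return 1

    def __getitem__(self, index: int) -> bool:
        return get_worker_info() is not None       # True: we are inside a worker process

if __name__ == "__main__":
    print(get_worker_info() is None)                        # True in the main process
    print(next(iter(DataLoader(_Probe(), num_workers=1))))  # tensor([True]) reported by the worker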


@dataclass
class RemoteDir:
"""Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/reader.py
@@ -228,7 +228,7 @@ def read(self, index: ChunkedIndex) -> Any:
chunk_filepath, begin, _ = self.config[index]
item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)

if index.last_index and self._prepare_thread:
if index.is_last_index and self._prepare_thread:
self._prepare_thread.stop()
self._prepare_thread = None

2 changes: 1 addition & 1 deletion src/lightning/data/streaming/sampler.py
@@ -25,7 +25,7 @@ class ChunkedIndex:
index: int
chunk_index: int
chunk_indexes: Optional[List[int]] = None
last_index: bool = False
is_last_index: bool = False


class CacheBatchSampler:
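
The dataclass change above is a pure rename from `last_index` to `is_last_index`; the flag still marks the final sample so the reader can stop its background download thread. A tiny construction example with illustrative values:

from lightning.data.streaming.sampler import ChunkedIndex

idx = ChunkedIndex(index=1023, chunk_index=9, is_last_index=True)   # hypothetical final sample
assert idx.is_last_index
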
20 changes: 12 additions & 8 deletions src/lightning/data/streaming/serializers.py
@@ -149,10 +149,10 @@ class TensorSerializer(Serializer):

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}

def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
dtype_indice = self._dtype_to_indices[item.dtype]
data = [np.uint32(dtype_indice).tobytes()]
data.append(np.uint32(len(item.shape)).tobytes())
for dim in item.shape:
@@ -182,14 +182,14 @@ class NoHeaderTensorSerializer(Serializer):

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
self._dtype_to_indices = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()}
self._dtype: Optional[torch.dtype] = None

def setup(self, data_format: str) -> None:
self._dtype = _TORCH_DTYPES_MAPPING[int(data_format.split(":")[1])]

def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
dtype_indice = self._dtype_to_indices[item.dtype]
return item.numpy().tobytes(order="C"), f"no_header_tensor:{dtype_indice}"

def deserialize(self, data: bytes) -> torch.Tensor:
@@ -205,10 +205,10 @@ class NumpySerializer(Serializer):

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}

def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
dtype_indice = self._dtype_to_indices[item.dtype]
data = [np.uint32(dtype_indice).tobytes()]
data.append(np.uint32(len(item.shape)).tobytes())
for dim in item.shape:
@@ -221,8 +221,12 @@ def deserialize(self, data: bytes) -> np.ndarray:
dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
shape_size = np.frombuffer(data[4:8], np.uint32).item()
shape = []
# deserialize the shape header
# Note: The start position of the shape value: 8 (dtype + shape length) + 4 * shape_idx
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())

# deserialize the numpy array bytes
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
if tensor.shape == shape:
return tensor
@@ -237,14 +241,14 @@ class NoHeaderNumpySerializer(Serializer):

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
self._dtype_to_indices = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
self._dtype: Optional[np.dtype] = None

def setup(self, data_format: str) -> None:
self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]

def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice: int = self._dtype_to_indice[item.dtype]
dtype_indice: int = self._dtype_to_indices[item.dtype]
return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"

def deserialize(self, data: bytes) -> np.ndarray:
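
The comments added to `NumpySerializer.deserialize` spell out the on-disk layout: a uint32 dtype index, a uint32 number of dimensions, one uint32 per dimension, then the raw C-order array bytes. A simplified standalone sketch of packing and unpacking that kind of header; it passes the dtype explicitly instead of encoding the dtype-index table, so it is not the serializer's exact format:

import numpy as np

def pack(item: np.ndarray) -> bytes:
    # [ndim: uint32][dim_0..dim_{n-1}: uint32 each][raw C-order bytes]
    header = [np.uint32(item.ndim).tobytes()]
    header += [np.uint32(dim).tobytes() for dim in item.shape]
    return b"".join(header) + item.tobytes(order="C")

def unpack(data: bytes, dtype: np.dtype) -> np.ndarray:
    ndim = np.frombuffer(data[0:4], np.uint32).item()
    shape = [np.frombuffer(data[4 + 4 * i : 8 + 4 * i], np.uint32).item() for i in range(ndim)]
    payload_start = 4 + 4 * ndim
    return np.frombuffer(data[payload_start:], dtype=dtype).reshape(shape)

arr = np.arange(12, dtype=np.float32).reshape(3, 4)
assert np.array_equal(unpack(pack(arr), np.float32), arr)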