
Fix: Resolve checkpointing for the Streaming Dataset (#19123)
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 7bd7577)
tchaton authored and lantiga committed Dec 20, 2023
1 parent ed95d55 commit 1e08816
Showing 4 changed files with 118 additions and 35 deletions.
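
For orientation, here is a minimal sketch of the on-disk layout the fix appears to converge on, based on the hunks below: each worker (shard) writes its own `checkpoints/checkpoint.json` inside its numbered cache directory, and the dataset's `state_dict` aggregates them from the shared parent directory. The root path and worker count here are hypothetical.

```python
import json
import os
import tempfile

# Hypothetical shared cache root; in the real dataset this comes from the Cache/input_dir.
root = tempfile.mkdtemp()

# Each worker/shard owns <root>/<shard_rank>/checkpoints/checkpoint.json.
for shard_rank in range(2):
    checkpoint_dir = os.path.join(root, str(shard_rank), "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, "checkpoint.json"), "w") as f:
        json.dump({"chunk_index": 0, "global_index": 0, "index": 0}, f)

# Aggregation in the style of `_load_state_dict_from_checkpoint_dir`: scan the parent
# for numeric worker folders and read each worker's checkpoint file.
state = {}
for entry in sorted(os.listdir(root)):
    if not entry.isdigit():
        continue  # skip non-worker entries such as index or chunk files
    with open(os.path.join(root, entry, "checkpoints", "checkpoint.json")) as f:
        state[entry] = json.load(f)
print(state)  # {'0': {...}, '1': {...}}
```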
9 changes: 4 additions & 5 deletions src/lightning/data/streaming/cache.py
@@ -111,16 +111,15 @@ def filled(self) -> bool:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done

@property
def cache_dir(self) -> str:
return self._cache_dir

@property
def checkpoint_dir(self) -> str:
checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
return self._try_create(checkpoint_dir)

@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
return self._try_create(checkpoint_rank_dir)

def _try_create(self, path: str) -> str:
os.makedirs(path, exist_ok=True)
return path
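
As a small illustration (the cache path below is made up), the new `checkpoint_dir` property resolves to a `checkpoints` folder directly under the worker's cache directory; the `checkpoint_rank_dir` property, which appended an extra per-rank folder on top of that, appears to be dropped in this hunk in favour of it.

```python
import os

worker_cache_dir = "/tmp/streaming_cache/0"  # hypothetical per-worker cache directory

# Equivalent of the checkpoint_dir property: create on first access, then return the path.
checkpoint_dir = os.path.join(worker_cache_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
print(checkpoint_dir)  # /tmp/streaming_cache/0/checkpoints
```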
19 changes: 15 additions & 4 deletions src/lightning/data/streaming/dataset.py
@@ -258,7 +258,7 @@ def __next__(self) -> Any:
self.index += 1

# Checkpoint based on time
if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
if self.checkpoint_interval and (time() - self.last_time) > self.checkpoint_interval:
self._checkpoint(self.chunk_index - 1)

return data
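
The one-line change above flips the elapsed-time computation: `self.last_time - time()` is essentially never positive, so the interval-based checkpoint previously never fired. A minimal sketch of the intended pattern, with illustrative names:

```python
from time import time

CHECKPOINT_INTERVAL = 30.0  # seconds; hypothetical value
last_time = time()

def maybe_checkpoint(save_fn) -> None:
    """Call save_fn once enough wall-clock time has elapsed since the last save."""
    global last_time
    if (time() - last_time) > CHECKPOINT_INTERVAL:  # elapsed time, not the reverse
        save_fn()
        last_time = time()  # reset the timer after a successful save
```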
@@ -298,7 +298,7 @@ def _checkpoint(self, chunk_index: int) -> None:
)

# 4. Move the file to its target position
shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))
shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_dir, "checkpoint.json"))

self.last_time = time()
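
For context, `_checkpoint` writes the state to a temporary file first and only then moves it onto `checkpoint.json` (step 4 above), so a reader never observes a half-written checkpoint. A rough standalone sketch of that pattern, with placeholder paths and state:

```python
import json
import os
import shutil
import tempfile

def write_checkpoint(checkpoint_dir: str, state: dict) -> None:
    """Dump state to a temp file, then move the finished file into place."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    # The move replaces any previous checkpoint only once the new file is complete.
    shutil.move(tmp_path, os.path.join(checkpoint_dir, "checkpoint.json"))

write_checkpoint("/tmp/demo_checkpoints", {"chunk_index": 2, "global_index": 4, "index": 0})
```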

@@ -316,7 +316,8 @@ def state_dict(self) -> Dict[str, Any]:
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
# We are reading at the workers level, so we take the dirname
state_dict = _load_state_dict_from_checkpoint_dir(os.path.dirname(self.cache.cache_dir))

if self.distributed_env.world_size > 1:
return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
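
The new comment about "reading at the workers level" relies on each worker's cache directory being a numbered child of a shared parent, so `os.path.dirname` climbs to the directory that holds every worker's checkpoints. A tiny illustration with a made-up path:

```python
import os

worker_cache_dir = "/tmp/streaming_cache/3"  # hypothetical <root>/<shard_rank> cache dir

parent = os.path.dirname(worker_cache_dir)
print(parent)  # /tmp/streaming_cache -> contains one numbered folder per worker
print(os.path.join(parent, "0", "checkpoints", "checkpoint.json"))
# /tmp/streaming_cache/0/checkpoints/checkpoint.json
```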
@@ -401,7 +402,9 @@ def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
if not os.path.exists(checkpoint_dir):
return state_dict
for worker_idx in os.listdir(checkpoint_dir):
checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
if not is_integer(worker_idx):
continue
checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoints", "checkpoint.json")
if not os.path.exists(checkpoint_filepath):
state_dict[worker_idx] = {}
else:
@@ -446,3 +449,11 @@ class RemoteDir:

cache_dir: str
remote: str


def is_integer(value: str) -> bool:
try:
int(value)
return True
except Exception:
return False
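
Usage note for the new `is_integer` guard: the parent cache directory can also hold non-worker entries (an index file, chunk files), so only numeric folder names are treated as worker checkpoint directories. A quick sketch under that assumption (the listing is fabricated):

```python
# Hypothetical contents of the parent cache directory.
entries = ["0", "1", "index.json", "chunk-0-0.bin"]

def is_integer(value: str) -> bool:  # same helper as above
    try:
        int(value)
        return True
    except Exception:
        return False

worker_dirs = [entry for entry in entries if is_integer(entry)]
print(worker_dirs)  # ['0', '1']
```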
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/downloader.py
@@ -65,7 +65,7 @@ class LocalDownloader(Downloader):
@classmethod
def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
if not os.path.exists(remote_filepath):
raise FileNotFoundError("The provided remote_path doesn't exist: {remote_path}")
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
if remote_filepath != local_filepath:
shutil.copy(remote_filepath, local_filepath)

123 changes: 98 additions & 25 deletions tests/tests_data/streaming/test_dataset.py
@@ -24,10 +24,17 @@
from lightning import seed_everything
from lightning.data.streaming import Cache, functions
from lightning.data.streaming.constants import _TIME_FORMAT
from lightning.data.streaming.dataset import StreamingDataset, _should_replace_path, _try_create_cache_dir
from lightning.data.streaming.dataset import (
_INDEX_FILENAME,
Dir,
RemoteDir,
StreamingDataset,
_should_replace_path,
_try_create_cache_dir,
)
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.data.utilities.env import _DistributedEnv
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from torch.utils.data import DataLoader


@@ -540,12 +547,42 @@ def test_s3_streaming_dataset():
assert dataset.input_dir.path is None


class EmulateS3StreamingDataset(StreamingDataset):
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)

cache_dir = os.path.join(self.input_dir.path, str(env.shard_rank))
os.makedirs(cache_dir, exist_ok=True)

cache = Cache(
input_dir=Dir(cache_dir, self.input_dir.url),
item_loader=self.item_loader,
chunk_bytes=1,
serializers=self.serializers,
)
cache._reader._try_load_config()

if not cache.filled:
raise ValueError(
f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
)

return cache


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_resumable_dataset_single_worker(tmpdir):
def test_resumable_dataset_two_workers(tmpdir):
seed_everything(42)

data_dir = os.path.join(tmpdir, "data")
cache_dir = os.path.join(tmpdir, "cache_dir")

os.makedirs(data_dir)
os.makedirs(cache_dir)

block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))

counter = 0
for i in range(100):
@@ -556,66 +593,100 @@ def test_resumable_dataset_single_worker(tmpdir):
cache.done()
cache.merge()

assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
assert len([f for f in os.listdir(data_dir) if f.endswith(".bin")]) == 50

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True
)

dataset.current_epoch = 1

assert dataset.state_dict() == {}

dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)

_ = next(dataloader_iter)
state_dict_0 = dataset.state_dict()

sleep(0.1)

assert state_dict_0["0"]["chunk_index"] == 0
state_dict_0 = dataset.state_dict()

assert sorted(state_dict_0.keys()) == ["0", "1"]

assert state_dict_0["0"]["chunk_index"] == 1
assert state_dict_0["0"]["global_index"] == 2
assert state_dict_0["0"]["index"] == 0

checkpoint_dir = os.path.join(tmpdir, "checkpoints")
assert os.listdir(checkpoint_dir) == ["0"]
assert state_dict_0["1"]["chunk_index"] == 0
assert state_dict_0["1"]["global_index"] == 0
assert state_dict_0["1"]["index"] == 0

_ = next(dataloader_iter)

sleep(0.1)

state_dict_1 = dataset.state_dict()
assert state_dict_1["0"]["chunk_index"] == 2

assert state_dict_1["0"]["chunk_index"] == 1
assert state_dict_1["0"]["global_index"] == 2
assert state_dict_1["0"]["index"] == 0

assert state_dict_1["1"]["chunk_index"] == 1
assert state_dict_1["1"]["global_index"] == 2
assert state_dict_1["1"]["index"] == 0

batch_2 = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3

assert state_dict_2["0"]["chunk_index"] == 2
assert state_dict_2["0"]["global_index"] == 4
assert state_dict_2["0"]["index"] == 0

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
assert state_dict_2["1"]["chunk_index"] == 1
assert state_dict_2["1"]["global_index"] == 2
assert state_dict_2["1"]["index"] == 0

dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir),
item_loader=TokensLoader(block_size),
shuffle=True,
)

dataset.load_state_dict(state_dict_1)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)
batch_0_restart = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3

assert state_dict_2["0"]["chunk_index"] == 2
assert state_dict_2["0"]["global_index"] == 4
assert state_dict_2["0"]["index"] == 0

assert state_dict_2["1"]["chunk_index"] == 1
assert state_dict_2["1"]["global_index"] == 2
assert state_dict_2["1"]["index"] == 0

assert torch.equal(batch_2, batch_0_restart)


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_dataset_valid_state(tmpdir):
seed_everything(42)

data_dir = os.path.join(tmpdir, "data")
cache_dir = os.path.join(tmpdir, "cache_dir")

os.makedirs(data_dir)
os.makedirs(cache_dir)

block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))

counter = 0
for i in range(100):
@@ -626,12 +697,14 @@ def test_dataset_valid_state(tmpdir):
cache.done()
cache.merge()

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False
)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
dataloader_iter = iter(dataloader)
next(dataloader_iter)

sleep(0.1)
sleep(1)

state_dict = dataset.state_dict()

@@ -666,15 +739,15 @@ def test_dataset_valid_state(tmpdir):
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.", # noqa E501
match=f"The provided `input_dir` URL state doesn't match the current one. Found `{data_dir}` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()

state_dict["0"]["input_dir_path"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.", # noqa E501
match=f"The provided `input_dir` path state doesn't match the current one. Found `{cache_dir}` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()

