Fix: Resolve checkpointing for the Streaming Dataset #19123

Merged: 6 commits, Dec 8, 2023

9 changes: 4 additions & 5 deletions src/lightning/data/streaming/cache.py
@@ -111,16 +111,15 @@ def filled(self) -> bool:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done

@property
def cache_dir(self) -> str:
return self._cache_dir

@property
def checkpoint_dir(self) -> str:
checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
return self._try_create(checkpoint_dir)

@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
return self._try_create(checkpoint_rank_dir)

def _try_create(self, path: str) -> str:
os.makedirs(path, exist_ok=True)
return path
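Reading note (not part of the diff): the hunk's 4 additions and 5 deletions, together with the dataset.py changes below, suggest the new `cache_dir` property is the addition here and the per-rank `checkpoint_rank_dir` property is what gets removed, so each worker cache now keeps its checkpoint directly under its own `checkpoints/` folder. A minimal sketch of what the kept `checkpoint_dir` property resolves to (the path is hypothetical):

import os

cache_dir = "/cache/chunks/0"                            # hypothetical per-worker cache dir
checkpoint_dir = os.path.join(cache_dir, "checkpoints")  # what the property returns
os.makedirs(checkpoint_dir, exist_ok=True)               # what _try_create(path) does
# dataset._checkpoint() later moves checkpoint.json into this folder (see dataset.py below)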
19 changes: 15 additions & 4 deletions src/lightning/data/streaming/dataset.py
@@ -258,7 +258,7 @@ def __next__(self) -> Any:
self.index += 1

# Checkpoint based on time
if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
if self.checkpoint_interval and (time() - self.last_time) > self.checkpoint_interval:
self._checkpoint(self.chunk_index - 1)

return data
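The one-line change above fixes the direction of the elapsed-time check: `self.last_time - time()` is always negative, so the interval-based checkpoint never fired. A standalone sketch of the corrected pattern (the names and the 30-second interval are illustrative, not from the diff):

from time import time

checkpoint_interval = 30.0  # seconds between checkpoints (hypothetical value)
last_time = time()

def should_checkpoint() -> bool:
    # Elapsed time must be computed as now - last, not last - now; the reversed
    # order used before this fix could never exceed a positive interval.
    return (time() - last_time) > checkpoint_interval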
@@ -298,7 +298,7 @@ def _checkpoint(self, chunk_index: int) -> None:
)

# 4. Move the file to its target position
shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))
shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_dir, "checkpoint.json"))

self.last_time = time()

@@ -316,7 +316,8 @@ def state_dict(self) -> Dict[str, Any]:
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
# We are reading at the workers level, so we take the dirname
state_dict = _load_state_dict_from_checkpoint_dir(os.path.dirname(self.cache.cache_dir))

if self.distributed_env.world_size > 1:
return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
@@ -401,7 +402,9 @@ def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
if not os.path.exists(checkpoint_dir):
return state_dict
for worker_idx in os.listdir(checkpoint_dir):
checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
if not is_integer(worker_idx):
continue
checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoints", "checkpoint.json")
if not os.path.exists(checkpoint_filepath):
state_dict[worker_idx] = {}
else:
@@ -446,3 +449,11 @@ class RemoteDir:

cache_dir: str
remote: str


def is_integer(value: str) -> bool:
try:
int(value)
return True
except Exception:
return False
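Together with the `state_dict` change above, worker checkpoints are now discovered by walking the parent of the per-worker cache directory: every integer-named entry is treated as a worker cache and its `checkpoints/checkpoint.json` is loaded if present. A rough, self-contained sketch of that aggregation (directory names are hypothetical):

import json
import os
from typing import Any, Dict

def load_worker_checkpoints(root: str) -> Dict[str, Any]:
    # Collect <root>/<worker_idx>/checkpoints/checkpoint.json for integer-named dirs.
    state: Dict[str, Any] = {}
    if not os.path.isdir(root):
        return state
    for entry in os.listdir(root):
        if not entry.isdigit():  # same idea as is_integer(): skip non-worker entries
            continue
        path = os.path.join(root, entry, "checkpoints", "checkpoint.json")
        if os.path.exists(path):
            with open(path) as f:
                state[entry] = json.load(f)
        else:
            state[entry] = {}
    return state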
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/downloader.py
@@ -65,7 +65,7 @@ class LocalDownloader(Downloader):
@classmethod
def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
if not os.path.exists(remote_filepath):
raise FileNotFoundError("The provided remote_path doesn't exist: {remote_path}")
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
if remote_filepath != local_filepath:
shutil.copy(remote_filepath, local_filepath)
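The only change in this file is the missing `f` prefix (plus the stale `remote_path` name in the message): without it, the raised error contains the literal text `{remote_path}` rather than the path that was not found. A two-line illustration with a hypothetical path:

missing = "/data/chunk-0.bin"                   # hypothetical path
print("remote_path doesn't exist: {missing}")   # prints the braces literally
print(f"remote_path doesn't exist: {missing}")  # prints the actual path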

123 changes: 98 additions & 25 deletions tests/tests_data/streaming/test_dataset.py
@@ -24,10 +24,17 @@
from lightning import seed_everything
from lightning.data.streaming import Cache, functions
from lightning.data.streaming.constants import _TIME_FORMAT
from lightning.data.streaming.dataset import StreamingDataset, _should_replace_path, _try_create_cache_dir
from lightning.data.streaming.dataset import (
_INDEX_FILENAME,
Dir,
RemoteDir,
StreamingDataset,
_should_replace_path,
_try_create_cache_dir,
)
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.data.utilities.env import _DistributedEnv
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from torch.utils.data import DataLoader


@@ -540,12 +547,42 @@ def test_s3_streaming_dataset():
assert dataset.input_dir.path is None


class EmulateS3StreamingDataset(StreamingDataset):
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)

cache_dir = os.path.join(self.input_dir.path, str(env.shard_rank))
os.makedirs(cache_dir, exist_ok=True)

cache = Cache(
input_dir=Dir(cache_dir, self.input_dir.url),
item_loader=self.item_loader,
chunk_bytes=1,
serializers=self.serializers,
)
cache._reader._try_load_config()

if not cache.filled:
raise ValueError(
f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
" HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
)

return cache
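This test helper stands in for an S3-backed input: each shard rank gets its own local cache directory under `input_dir.path`, which is what lets the two dataloader workers in the tests below checkpoint and resume independently. A rough sketch of the resulting layout (the shard-rank formula is an assumption about `Environment.shard_rank`, not something stated in this diff):

import os

cache_root = "/tmp/cache_dir"       # hypothetical input_dir.path
world_size, num_workers = 1, 2      # matches the tests below
for dist_rank in range(world_size):
    for worker_rank in range(num_workers):
        shard_rank = dist_rank * num_workers + worker_rank  # assumed formula
        os.makedirs(os.path.join(cache_root, str(shard_rank)), exist_ok=True)
# -> /tmp/cache_dir/0 and /tmp/cache_dir/1, one Cache per dataloader worker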


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_resumable_dataset_single_worker(tmpdir):
def test_resumable_dataset_two_workers(tmpdir):
seed_everything(42)

data_dir = os.path.join(tmpdir, "data")
cache_dir = os.path.join(tmpdir, "cache_dir")

os.makedirs(data_dir)
os.makedirs(cache_dir)

block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))

counter = 0
for i in range(100):
@@ -556,66 +593,100 @@ def test_resumable_dataset_single_worker(tmpdir):
cache.done()
cache.merge()

assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
assert len([f for f in os.listdir(data_dir) if f.endswith(".bin")]) == 50

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True
)

dataset.current_epoch = 1

assert dataset.state_dict() == {}

dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)

_ = next(dataloader_iter)
state_dict_0 = dataset.state_dict()

sleep(0.1)

assert state_dict_0["0"]["chunk_index"] == 0
state_dict_0 = dataset.state_dict()

assert sorted(state_dict_0.keys()) == ["0", "1"]

assert state_dict_0["0"]["chunk_index"] == 1
assert state_dict_0["0"]["global_index"] == 2
assert state_dict_0["0"]["index"] == 0

checkpoint_dir = os.path.join(tmpdir, "checkpoints")
assert os.listdir(checkpoint_dir) == ["0"]
assert state_dict_0["1"]["chunk_index"] == 0
assert state_dict_0["1"]["global_index"] == 0
assert state_dict_0["1"]["index"] == 0

_ = next(dataloader_iter)

sleep(0.1)

state_dict_1 = dataset.state_dict()
assert state_dict_1["0"]["chunk_index"] == 2

assert state_dict_1["0"]["chunk_index"] == 1
assert state_dict_1["0"]["global_index"] == 2
assert state_dict_1["0"]["index"] == 0

assert state_dict_1["1"]["chunk_index"] == 1
assert state_dict_1["1"]["global_index"] == 2
assert state_dict_1["1"]["index"] == 0

batch_2 = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3

assert state_dict_2["0"]["chunk_index"] == 2
assert state_dict_2["0"]["global_index"] == 4
assert state_dict_2["0"]["index"] == 0

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
assert state_dict_2["1"]["chunk_index"] == 1
assert state_dict_2["1"]["global_index"] == 2
assert state_dict_2["1"]["index"] == 0

dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir),
item_loader=TokensLoader(block_size),
shuffle=True,
)

dataset.load_state_dict(state_dict_1)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)

dataloader_iter = iter(dataloader)
batch_0_restart = next(dataloader_iter)

sleep(0.1)

state_dict_2 = dataset.state_dict()
assert state_dict_2["0"]["chunk_index"] == 3

assert state_dict_2["0"]["chunk_index"] == 2
assert state_dict_2["0"]["global_index"] == 4
assert state_dict_2["0"]["index"] == 0

assert state_dict_2["1"]["chunk_index"] == 1
assert state_dict_2["1"]["global_index"] == 2
assert state_dict_2["1"]["index"] == 0

assert torch.equal(batch_2, batch_0_restart)


@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
def test_dataset_valid_state(tmpdir):
seed_everything(42)

data_dir = os.path.join(tmpdir, "data")
cache_dir = os.path.join(tmpdir, "cache_dir")

os.makedirs(data_dir)
os.makedirs(cache_dir)

block_size = 20
cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))

counter = 0
for i in range(100):
@@ -626,12 +697,14 @@ def test_dataset_valid_state(tmpdir):
cache.done()
cache.merge()

dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
dataset = EmulateS3StreamingDataset(
input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False
)
dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
dataloader_iter = iter(dataloader)
next(dataloader_iter)

sleep(0.1)
sleep(1)

state_dict = dataset.state_dict()

@@ -666,15 +739,15 @@ def test_dataset_valid_state(tmpdir):
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.", # noqa E501
match=f"The provided `input_dir` URL state doesn't match the current one. Found `{data_dir}` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()

state_dict["0"]["input_dir_path"] = "toto"
dataset.load_state_dict(state_dict)
with pytest.raises(
ValueError,
match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.", # noqa E501
match=f"The provided `input_dir` path state doesn't match the current one. Found `{cache_dir}` instead of `toto`.", # noqa E501
):
dataset._validate_state_dict()
