diff --git a/src/lightning/data/streaming/cache.py b/src/lightning/data/streaming/cache.py
index 349ba824871e8..488c6f7cd6b4d 100644
--- a/src/lightning/data/streaming/cache.py
+++ b/src/lightning/data/streaming/cache.py
@@ -111,16 +111,15 @@ def filled(self) -> bool:
         self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
         return self._is_done
 
+    @property
+    def cache_dir(self) -> str:
+        return self._cache_dir
+
     @property
     def checkpoint_dir(self) -> str:
         checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
         return self._try_create(checkpoint_dir)
 
-    @property
-    def checkpoint_rank_dir(self) -> str:
-        checkpoint_rank_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
-        return self._try_create(checkpoint_rank_dir)
-
     def _try_create(self, path: str) -> str:
         os.makedirs(path, exist_ok=True)
         return path
diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py
index 1054be85f7dd5..cdf5d20bcdacb 100644
--- a/src/lightning/data/streaming/dataset.py
+++ b/src/lightning/data/streaming/dataset.py
@@ -258,7 +258,7 @@ def __next__(self) -> Any:
         self.index += 1
 
         # Checkpoint based on time
-        if self.checkpoint_interval and (self.last_time - time()) > self.checkpoint_interval:
+        if self.checkpoint_interval and (time() - self.last_time) > self.checkpoint_interval:
             self._checkpoint(self.chunk_index - 1)
 
         return data
@@ -298,7 +298,7 @@ def _checkpoint(self, chunk_index: int) -> None:
         )
 
         # 4. Move the file to its target position
-        shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json"))
+        shutil.move(tmp_checkpoint_path, os.path.join(self.cache.checkpoint_dir, "checkpoint.json"))
 
         self.last_time = time()
 
@@ -316,7 +316,8 @@ def state_dict(self) -> Dict[str, Any]:
         if not os.path.exists(self.cache.checkpoint_dir):
             return state_dict
 
-        state_dict = _load_state_dict_from_checkpoint_dir(self.cache.checkpoint_dir)
+        # We are reading at the workers level, so we take the dirname
+        state_dict = _load_state_dict_from_checkpoint_dir(os.path.dirname(self.cache.cache_dir))
 
         if self.distributed_env.world_size > 1:
             return _collect_distributed_state_dict(state_dict, self.distributed_env.world_size)
@@ -401,7 +402,9 @@ def _load_state_dict_from_checkpoint_dir(checkpoint_dir: str) -> Dict[str, Any]:
     if not os.path.exists(checkpoint_dir):
         return state_dict
     for worker_idx in os.listdir(checkpoint_dir):
-        checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoint.json")
+        if not is_integer(worker_idx):
+            continue
+        checkpoint_filepath = os.path.join(checkpoint_dir, str(worker_idx), "checkpoints", "checkpoint.json")
         if not os.path.exists(checkpoint_filepath):
             state_dict[worker_idx] = {}
         else:
@@ -446,3 +449,11 @@ class RemoteDir:
 
     cache_dir: str
     remote: str
+
+
+def is_integer(value: str) -> bool:
+    try:
+        int(value)
+        return True
+    except Exception:
+        return False
diff --git a/src/lightning/data/streaming/downloader.py b/src/lightning/data/streaming/downloader.py
index f6736ddfbace8..4270e7c7b9571 100644
--- a/src/lightning/data/streaming/downloader.py
+++ b/src/lightning/data/streaming/downloader.py
@@ -65,7 +65,7 @@ class LocalDownloader(Downloader):
     @classmethod
     def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
         if not os.path.exists(remote_filepath):
-            raise FileNotFoundError("The provided remote_path doesn't exist: {remote_path}")
+            raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
 
         if remote_filepath != local_filepath:
             shutil.copy(remote_filepath, local_filepath)
diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py
index 317cee748147a..5b5ff424bb5bb 100644
--- a/tests/tests_data/streaming/test_dataset.py
+++ b/tests/tests_data/streaming/test_dataset.py
@@ -24,10 +24,17 @@
 from lightning import seed_everything
 from lightning.data.streaming import Cache, functions
 from lightning.data.streaming.constants import _TIME_FORMAT
-from lightning.data.streaming.dataset import StreamingDataset, _should_replace_path, _try_create_cache_dir
+from lightning.data.streaming.dataset import (
+    _INDEX_FILENAME,
+    Dir,
+    RemoteDir,
+    StreamingDataset,
+    _should_replace_path,
+    _try_create_cache_dir,
+)
 from lightning.data.streaming.item_loader import TokensLoader
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
-from lightning.data.utilities.env import _DistributedEnv
+from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
 from torch.utils.data import DataLoader
 
 
@@ -540,12 +547,42 @@ def test_s3_streaming_dataset():
     assert dataset.input_dir.path is None
 
 
+class EmulateS3StreamingDataset(StreamingDataset):
+    def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
+        env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
+
+        cache_dir = os.path.join(self.input_dir.path, str(env.shard_rank))
+        os.makedirs(cache_dir, exist_ok=True)
+
+        cache = Cache(
+            input_dir=Dir(cache_dir, self.input_dir.url),
+            item_loader=self.item_loader,
+            chunk_bytes=1,
+            serializers=self.serializers,
+        )
+        cache._reader._try_load_config()
+
+        if not cache.filled:
+            raise ValueError(
+                f"The provided dataset `{self.input_dir}` doesn't contain any {_INDEX_FILENAME} file."
+                " HINT: Did you successfully optimize a dataset to the provided `input_dir`?"
+            )
+
+        return cache
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
-def test_resumable_dataset_single_worker(tmpdir):
+def test_resumable_dataset_two_workers(tmpdir):
     seed_everything(42)
 
+    data_dir = os.path.join(tmpdir, "data")
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+
+    os.makedirs(data_dir)
+    os.makedirs(cache_dir)
+
     block_size = 20
-    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+    cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))
 
     counter = 0
     for i in range(100):
@@ -556,47 +593,69 @@ def test_resumable_dataset_single_worker(tmpdir):
     cache.done()
     cache.merge()
 
-    assert len([f for f in os.listdir(tmpdir) if f.endswith(".bin")]) == 50
+    assert len([f for f in os.listdir(data_dir) if f.endswith(".bin")]) == 50
 
-    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+    dataset = EmulateS3StreamingDataset(
+        input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=True
+    )
 
     dataset.current_epoch = 1
-
-    assert dataset.state_dict() == {}
-
-    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+    dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)
     dataloader_iter = iter(dataloader)
 
     _ = next(dataloader_iter)
-    state_dict_0 = dataset.state_dict()
 
     sleep(0.1)
 
-    assert state_dict_0["0"]["chunk_index"] == 0
+    state_dict_0 = dataset.state_dict()
+
+    assert sorted(state_dict_0.keys()) == ["0", "1"]
+
+    assert state_dict_0["0"]["chunk_index"] == 1
+    assert state_dict_0["0"]["global_index"] == 2
     assert state_dict_0["0"]["index"] == 0
 
-    checkpoint_dir = os.path.join(tmpdir, "checkpoints")
-    assert os.listdir(checkpoint_dir) == ["0"]
+    assert state_dict_0["1"]["chunk_index"] == 0
+    assert state_dict_0["1"]["global_index"] == 0
+    assert state_dict_0["1"]["index"] == 0
+
     _ = next(dataloader_iter)
 
     sleep(0.1)
 
     state_dict_1 = dataset.state_dict()
-    assert state_dict_1["0"]["chunk_index"] == 2
+
+    assert state_dict_1["0"]["chunk_index"] == 1
+    assert state_dict_1["0"]["global_index"] == 2
     assert state_dict_1["0"]["index"] == 0
 
+    assert state_dict_1["1"]["chunk_index"] == 1
+    assert state_dict_1["1"]["global_index"] == 2
+    assert state_dict_1["1"]["index"] == 0
+
     batch_2 = next(dataloader_iter)
 
     sleep(0.1)
 
     state_dict_2 = dataset.state_dict()
-    assert state_dict_2["0"]["chunk_index"] == 3
+
+    assert state_dict_2["0"]["chunk_index"] == 2
+    assert state_dict_2["0"]["global_index"] == 4
     assert state_dict_2["0"]["index"] == 0
 
-    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+    assert state_dict_2["1"]["chunk_index"] == 1
+    assert state_dict_2["1"]["global_index"] == 2
+    assert state_dict_2["1"]["index"] == 0
+
+    dataset = EmulateS3StreamingDataset(
+        input_dir=RemoteDir(cache_dir, data_dir),
+        item_loader=TokensLoader(block_size),
+        shuffle=True,
+    )
+
     dataset.load_state_dict(state_dict_1)
-    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+    dataloader = DataLoader(dataset, num_workers=2, batch_size=2, prefetch_factor=1)
     dataloader_iter = iter(dataloader)
 
     batch_0_restart = next(dataloader_iter)
@@ -604,9 +663,15 @@ def test_resumable_dataset_single_worker(tmpdir):
     sleep(0.1)
 
     state_dict_2 = dataset.state_dict()
-    assert state_dict_2["0"]["chunk_index"] == 3
+
+    assert state_dict_2["0"]["chunk_index"] == 2
+    assert state_dict_2["0"]["global_index"] == 4
     assert state_dict_2["0"]["index"] == 0
 
+    assert state_dict_2["1"]["chunk_index"] == 1
+    assert state_dict_2["1"]["global_index"] == 2
+    assert state_dict_2["1"]["index"] == 0
+
     assert torch.equal(batch_2, batch_0_restart)
 
 
@@ -614,8 +679,14 @@ def test_dataset_valid_state(tmpdir):
     seed_everything(42)
 
+    data_dir = os.path.join(tmpdir, "data")
+    cache_dir = os.path.join(tmpdir, "cache_dir")
+
+    os.makedirs(data_dir)
+    os.makedirs(cache_dir)
+
     block_size = 20
-    cache = Cache(input_dir=str(tmpdir), chunk_size=40, item_loader=TokensLoader(block_size))
+    cache = Cache(input_dir=str(data_dir), chunk_size=40, item_loader=TokensLoader(block_size))
 
     counter = 0
     for i in range(100):
@@ -626,12 +697,14 @@ def test_dataset_valid_state(tmpdir):
     cache.done()
     cache.merge()
 
-    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=False)
-    dataloader = DataLoader(dataset, num_workers=1, batch_size=2, prefetch_factor=1)
+    dataset = EmulateS3StreamingDataset(
+        input_dir=RemoteDir(cache_dir, data_dir), item_loader=TokensLoader(block_size), shuffle=False
+    )
+    dataloader = DataLoader(dataset, num_workers=1, batch_size=2)
     dataloader_iter = iter(dataloader)
     next(dataloader_iter)
 
-    sleep(0.1)
+    sleep(1)
 
     state_dict = dataset.state_dict()
 
@@ -666,7 +739,7 @@ def test_dataset_valid_state(tmpdir):
     dataset.load_state_dict(state_dict)
     with pytest.raises(
         ValueError,
-        match="The provided `input_dir` URL state doesn't match the current one. Found `None` instead of `toto`.",  # noqa E501
+        match=f"The provided `input_dir` URL state doesn't match the current one. Found `{data_dir}` instead of `toto`.",  # noqa E501
     ):
         dataset._validate_state_dict()
 
@@ -674,7 +747,7 @@ def test_dataset_valid_state(tmpdir):
     dataset.load_state_dict(state_dict)
     with pytest.raises(
         ValueError,
-        match=f"The provided `input_dir` path state doesn't match the current one. Found `{tmpdir}` instead of `toto`.",  # noqa E501
+        match=f"The provided `input_dir` path state doesn't match the current one. Found `{cache_dir}` instead of `toto`.",  # noqa E501
     ):
         dataset._validate_state_dict()
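Note for reviewers (not part of the patch): a minimal standalone sketch of the per-worker checkpoint layout that the updated `_load_state_dict_from_checkpoint_dir` now expects, i.e. one integer-named directory per worker rank, each holding `checkpoints/checkpoint.json`. The `load_worker_states` name, the `workers_dir` argument, and the use of `json.load` are illustrative assumptions about how the real helper consumes the checkpoint files.

import json
import os


def is_integer(value: str) -> bool:
    # Same helper the patch adds to dataset.py: per-worker folders are named by rank.
    try:
        int(value)
        return True
    except Exception:
        return False


def load_worker_states(workers_dir: str) -> dict:
    # Hypothetical standalone mirror of the patched `_load_state_dict_from_checkpoint_dir`:
    # iterate the cache root, keep only integer-named (per-worker) folders, and read
    # `<rank>/checkpoints/checkpoint.json` when it exists.
    state_dict: dict = {}
    if not os.path.exists(workers_dir):
        return state_dict
    for worker_idx in os.listdir(workers_dir):
        if not is_integer(worker_idx):
            continue  # skip the index file or any other non-rank entry
        checkpoint_filepath = os.path.join(workers_dir, worker_idx, "checkpoints", "checkpoint.json")
        if not os.path.exists(checkpoint_filepath):
            state_dict[worker_idx] = {}
        else:
            with open(checkpoint_filepath) as checkpoint_file:
                state_dict[worker_idx] = json.load(checkpoint_file)
    return state_dict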