From 89f6562a99222227d8d88ce2ed844c3f14ae5e79 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 12:58:55 +0000 Subject: [PATCH 1/7] update --- src/lightning/data/streaming/dataset.py | 32 +++++++++++++++----- tests/tests_data/streaming/test_dataset.py | 34 +++++++++++++--------- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 495e36b5daa3f..2e724c2195481 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -13,6 +13,7 @@ import hashlib import os +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import numpy as np @@ -26,8 +27,17 @@ from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv + +@dataclass +class RemoteDir: + """Holds a remote URL to a directory and a cache directory where the data will be downloaded.""" + + cache_dir: str = None + remote: str = None + + if _LIGHTNING_CLOUD_LATEST: - from lightning_cloud.resolver import _resolve_dir + from lightning_cloud.resolver import Dir, _resolve_dir class StreamingDataset(IterableDataset): @@ -35,7 +45,7 @@ class StreamingDataset(IterableDataset): def __init__( self, - input_dir: str, + input_dir: Union[str, RemoteDir], item_loader: Optional[BaseItemLoader] = None, shuffle: bool = False, drop_last: bool = False, @@ -58,6 +68,9 @@ def __init__( if not isinstance(shuffle, bool): raise ValueError(f"Shuffle should be a boolean. Found {shuffle}") + if isinstance(input_dir, RemoteDir): + input_dir = Dir(path=input_dir.cache_dir, url=input_dir.remote) + input_dir = _resolve_dir(input_dir) self.input_dir = input_dir @@ -84,9 +97,10 @@ def __init__( def _create_cache(self, worker_env: _WorkerEnv) -> Cache: env = Environment(dist_env=self.distributed_env, worker_env=worker_env) - # TODO: Move this to lightning-cloud - if "this_" not in self.input_dir.path: - cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank) + if self.input_dir.path is None: + cache_path = _try_create_cache_dir( + input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, shard_rank=env.shard_rank + ) if cache_path is not None: self.input_dir.path = cache_path @@ -194,9 +208,13 @@ def __next__(self) -> Any: def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]: - if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: - return None hash_object = hashlib.md5(input_dir.encode()) + if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: + cache_dir = os.path.join( + os.path.expanduser("~"), ".lightning", "chunks", hash_object.hexdigest(), str(shard_rank) + ) + os.makedirs(cache_dir, exist_ok=True) + return cache_dir cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank)) os.makedirs(cache_dir, exist_ok=True) return cache_dir diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index e5c58661ad758..2e3d91ae6f989 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -12,7 +12,6 @@ # limitations under the License. import os -import sys from unittest import mock import numpy as np @@ -20,7 +19,7 @@ import torch from lightning import seed_everything from lightning.data.streaming import Cache, functions -from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir +from lightning.data.streaming.dataset import RemoteDir, StreamingDataset, _try_create_cache_dir from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.shuffle import FullShuffle, NoShuffle from lightning.data.utilities.env import _DistributedEnv @@ -55,17 +54,6 @@ def test_streaming_dataset(tmpdir, monkeypatch): assert len(dataloader) == 6 -@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"}) -@mock.patch("lightning.data.streaming.dataset.os.makedirs") -@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_create_cache_dir_in_lightning_cloud(makedirs_mock): - # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call - dataset = StreamingDataset("dummy") - with pytest.raises(FileNotFoundError, match="/0` doesn't exist"): - iter(dataset) - makedirs_mock.assert_called() - - @pytest.mark.parametrize("drop_last", [False, True]) def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir): seed_everything(42) @@ -307,7 +295,7 @@ def test_dataset_cache_recreation(tmpdir): def test_try_create_cache_dir(): with mock.patch.dict(os.environ, {}, clear=True): - assert _try_create_cache_dir("any") is None + assert _try_create_cache_dir("any") == "/Users/thomas/.lightning/chunks/100b8cad7cf2a56f6df78f171f97a1ec/0" # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()` with ( @@ -534,3 +522,21 @@ def fn(item): for batch_idx, batch in enumerate(dataloader): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] + + +def test_s3_streaming_dataset(tmpdir): + dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet") + assert len(dataset) == 1000 + expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8) + generated = dataset[0][0][0][:5] + assert torch.equal(generated, expected) + + dataset = StreamingDataset( + input_dir=RemoteDir(cache_dir=str(tmpdir), remote="s3://pl-flash-data/optimized_tiny_imagenet") + ) + assert len(dataset) == 1000 + expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8) + generated = dataset[0][0][0][:5] + assert torch.equal(generated, expected) + + assert sorted(os.listdir(tmpdir)) == ["chunk-0-0.bin", "index.json"] From 96c3d455a4fd24f9d0e09b3c35ee7454ba79944a Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 12:59:46 +0000 Subject: [PATCH 2/7] update --- src/lightning/data/streaming/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 2e724c2195481..46ff5e903b786 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -32,8 +32,8 @@ class RemoteDir: """Holds a remote URL to a directory and a cache directory where the data will be downloaded.""" - cache_dir: str = None - remote: str = None + cache_dir: str + remote: str if _LIGHTNING_CLOUD_LATEST: From b83dbe1e03c5c65c57cd83f3a1c28837e18bd1e8 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 13:19:08 +0000 Subject: [PATCH 3/7] update --- tests/tests_data/streaming/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 2e3d91ae6f989..09408b480981d 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -295,7 +295,7 @@ def test_dataset_cache_recreation(tmpdir): def test_try_create_cache_dir(): with mock.patch.dict(os.environ, {}, clear=True): - assert _try_create_cache_dir("any") == "/Users/thomas/.lightning/chunks/100b8cad7cf2a56f6df78f171f97a1ec/0" + assert f"{os.sep}".join(["chunks", "100b8cad7cf2a56f6df78f171f97a1ec", "0"]) in _try_create_cache_dir("any") # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()` with ( From 0d6521033f7b7222c257c6816f039defbc7bd033 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 13:32:08 +0000 Subject: [PATCH 4/7] update --- .github/workflows/ci-tests-data.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index fccb4394fa999..22a51759813e6 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -94,6 +94,10 @@ jobs: - name: Testing Data working-directory: tests/tests_data + env: + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_PUB_ACCESS_KEY }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_PUB_SECRET_KEY }} + AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }} # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 run: | python -m coverage run --source lightning \ From e8f5610760b8c2b576f64dbb67ac6556f8ab877e Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 15:09:23 +0000 Subject: [PATCH 5/7] update --- .github/workflows/ci-tests-data.yml | 1 + tests/tests_data/streaming/test_dataset.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index 22a51759813e6..76f8ee39a6e37 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -99,6 +99,7 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_PUB_SECRET_KEY }} AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }} # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 + timeout-minutes: 10 run: | python -m coverage run --source lightning \ -m pytest -v --timeout=60 --durations=60 diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 09408b480981d..99d01b4abf7cd 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -12,6 +12,7 @@ # limitations under the License. import os +import sys from unittest import mock import numpy as np @@ -481,9 +482,6 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir) monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir) - def fn(item): - return torch.arange(item[0], item[0] + 20).to(torch.int) - functions.optimize( optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1 ) @@ -524,6 +522,7 @@ def fn(item): assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] +@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on windows and MacOs") def test_s3_streaming_dataset(tmpdir): dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet") assert len(dataset) == 1000 From 4e379977a94195061c05a2490bd8f048d27e2ccc Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 15:39:18 +0000 Subject: [PATCH 6/7] update --- src/lightning/data/streaming/constants.py | 4 ++++ src/lightning/data/streaming/dataset.py | 25 ++++++++++------------ tests/tests_data/streaming/test_dataset.py | 2 +- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py index 19091906a9ff5..65239b14f974a 100644 --- a/src/lightning/data/streaming/constants.py +++ b/src/lightning/data/streaming/constants.py @@ -11,12 +11,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from pathlib import Path + import torch from lightning_utilities.core.imports import RequirementCache _INDEX_FILENAME = "index.json" _DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B _DEFAULT_FAST_DEV_RUN_ITEMS = 10 +_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks") # This is required for full pytree serialization / deserialization support _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0") diff --git a/src/lightning/data/streaming/dataset.py b/src/lightning/data/streaming/dataset.py index 46ff5e903b786..ab92e55f880c9 100644 --- a/src/lightning/data/streaming/dataset.py +++ b/src/lightning/data/streaming/dataset.py @@ -20,22 +20,13 @@ from torch.utils.data import IterableDataset from lightning.data.streaming import Cache -from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST +from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST from lightning.data.streaming.item_loader import BaseItemLoader from lightning.data.streaming.sampler import ChunkedIndex from lightning.data.streaming.serializers import Serializer from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv - -@dataclass -class RemoteDir: - """Holds a remote URL to a directory and a cache directory where the data will be downloaded.""" - - cache_dir: str - remote: str - - if _LIGHTNING_CLOUD_LATEST: from lightning_cloud.resolver import Dir, _resolve_dir @@ -45,7 +36,7 @@ class StreamingDataset(IterableDataset): def __init__( self, - input_dir: Union[str, RemoteDir], + input_dir: Union[str, "RemoteDir"], item_loader: Optional[BaseItemLoader] = None, shuffle: bool = False, drop_last: bool = False, @@ -210,11 +201,17 @@ def __next__(self) -> Any: def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]: hash_object = hashlib.md5(input_dir.encode()) if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ: - cache_dir = os.path.join( - os.path.expanduser("~"), ".lightning", "chunks", hash_object.hexdigest(), str(shard_rank) - ) + cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank)) os.makedirs(cache_dir, exist_ok=True) return cache_dir cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank)) os.makedirs(cache_dir, exist_ok=True) return cache_dir + + +@dataclass +class RemoteDir: + """Holds a remote URL to a directory and a cache directory where the data will be downloaded.""" + + cache_dir: str + remote: str diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 99d01b4abf7cd..13c2d92b230b4 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -296,7 +296,7 @@ def test_dataset_cache_recreation(tmpdir): def test_try_create_cache_dir(): with mock.patch.dict(os.environ, {}, clear=True): - assert f"{os.sep}".join(["chunks", "100b8cad7cf2a56f6df78f171f97a1ec", "0"]) in _try_create_cache_dir("any") + assert os.path.join("chunks", "100b8cad7cf2a56f6df78f171f97a1ec", "0") in _try_create_cache_dir("any") # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()` with ( From 6ae9f33920acc4654906c5a5ceb1b4780e3d4e4b Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 21 Nov 2023 19:59:40 +0000 Subject: [PATCH 7/7] update --- .github/workflows/ci-tests-data.yml | 5 ----- tests/tests_data/streaming/test_dataset.py | 22 +++++----------------- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml index 76f8ee39a6e37..fccb4394fa999 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -94,12 +94,7 @@ jobs: - name: Testing Data working-directory: tests/tests_data - env: - AWS_ACCESS_KEY_ID: ${{ secrets.AWS_PUB_ACCESS_KEY }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_PUB_SECRET_KEY }} - AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }} # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - timeout-minutes: 10 run: | python -m coverage run --source lightning \ -m pytest -v --timeout=60 --durations=60 diff --git a/tests/tests_data/streaming/test_dataset.py b/tests/tests_data/streaming/test_dataset.py index 13c2d92b230b4..351046c8d8bc0 100644 --- a/tests/tests_data/streaming/test_dataset.py +++ b/tests/tests_data/streaming/test_dataset.py @@ -20,7 +20,7 @@ import torch from lightning import seed_everything from lightning.data.streaming import Cache, functions -from lightning.data.streaming.dataset import RemoteDir, StreamingDataset, _try_create_cache_dir +from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir from lightning.data.streaming.item_loader import TokensLoader from lightning.data.streaming.shuffle import FullShuffle, NoShuffle from lightning.data.utilities.env import _DistributedEnv @@ -522,20 +522,8 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx] -@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on windows and MacOs") -def test_s3_streaming_dataset(tmpdir): +@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs") +def test_s3_streaming_dataset(): dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet") - assert len(dataset) == 1000 - expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8) - generated = dataset[0][0][0][:5] - assert torch.equal(generated, expected) - - dataset = StreamingDataset( - input_dir=RemoteDir(cache_dir=str(tmpdir), remote="s3://pl-flash-data/optimized_tiny_imagenet") - ) - assert len(dataset) == 1000 - expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8) - generated = dataset[0][0][0][:5] - assert torch.equal(generated, expected) - - assert sorted(os.listdir(tmpdir)) == ["chunk-0-0.bin", "index.json"] + assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet" + assert dataset.input_dir.path is None