Add direct s3 support to the streaming dataset (#19044)
Co-authored-by: thomas <[email protected]>
(cherry picked from commit bc16580)
tchaton authored and lantiga committed Dec 20, 2023
1 parent 5bf93ca commit 9e77488
Showing 3 changed files with 35 additions and 23 deletions.
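
In short: a StreamingDataset can now point directly at an S3 prefix, with chunks cached locally on demand. A minimal sketch of the new usage, mirroring the test added at the bottom of this commit (the module path follows src/lightning/data/streaming/dataset.py; treat it as an assumption for other versions):

from lightning.data.streaming.dataset import StreamingDataset

# Stream straight from S3: the resolver keeps the remote URL and leaves the
# local path unset until a cache directory is created lazily.
dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
assert dataset.input_dir.path is None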
4 changes: 4 additions & 0 deletions src/lightning/data/streaming/constants.py
@@ -11,12 +11,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+from pathlib import Path
+
 import torch
 from lightning_utilities.core.imports import RequirementCache
 
 _INDEX_FILENAME = "index.json"
 _DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB
 _DEFAULT_FAST_DEV_RUN_ITEMS = 10
+_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
 
 # This is required for full pytree serialization / deserialization support
 _TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
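
For reference, the new _DEFAULT_CACHE_DIR expands under the user's home directory; a quick sketch (the printed path is illustrative and depends on the machine):

import os
from pathlib import Path

cache_root = os.path.join(Path.home(), ".lightning", "chunks")
print(cache_root)  # e.g. /home/user/.lightning/chunks on Linux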
31 changes: 23 additions & 8 deletions src/lightning/data/streaming/dataset.py
@@ -13,29 +13,30 @@
 
 import hashlib
 import os
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 from torch.utils.data import IterableDataset
 
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
 from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
 from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
 
 if _LIGHTNING_CLOUD_LATEST:
-    from lightning_cloud.resolver import _resolve_dir
+    from lightning_cloud.resolver import Dir, _resolve_dir
 
 
 class StreamingDataset(IterableDataset):
     """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""
 
     def __init__(
         self,
-        input_dir: str,
+        input_dir: Union[str, "RemoteDir"],
         item_loader: Optional[BaseItemLoader] = None,
         shuffle: bool = False,
         drop_last: bool = False,
@@ -58,6 +59,9 @@ def __init__(
         if not isinstance(shuffle, bool):
             raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")
 
+        if isinstance(input_dir, RemoteDir):
+            input_dir = Dir(path=input_dir.cache_dir, url=input_dir.remote)
+
         input_dir = _resolve_dir(input_dir)
 
         self.input_dir = input_dir
@@ -84,9 +88,10 @@ def __init__(
     def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
         env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
 
-        # TODO: Move this to lightning-cloud
-        if "this_" not in self.input_dir.path:
-            cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
+        if self.input_dir.path is None:
+            cache_path = _try_create_cache_dir(
+                input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, shard_rank=env.shard_rank
+            )
             if cache_path is not None:
                 self.input_dir.path = cache_path
 
@@ -194,9 +199,19 @@ def __next__(self) -> Any:
 
 
 def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
-    if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
-        return None
     hash_object = hashlib.md5(input_dir.encode())
+    if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
+        cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank))
+        os.makedirs(cache_dir, exist_ok=True)
+        return cache_dir
     cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank))
     os.makedirs(cache_dir, exist_ok=True)
     return cache_dir
+
+
+@dataclass
+class RemoteDir:
+    """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
+
+    cache_dir: str
+    remote: str
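
Beyond plain s3:// strings, input_dir can now also be a RemoteDir pairing a remote URL with an explicit local cache directory, which __init__ converts into a lightning_cloud Dir. A sketch with hypothetical bucket and cache paths:

from lightning.data.streaming.dataset import RemoteDir, StreamingDataset

# Both locations below are placeholders, not paths from this commit.
remote = RemoteDir(cache_dir="/tmp/imagenet_cache", remote="s3://my-bucket/optimized_dataset")
dataset = StreamingDataset(input_dir=remote)
# Internally this becomes Dir(path="/tmp/imagenet_cache", url="s3://my-bucket/optimized_dataset").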
23 changes: 8 additions & 15 deletions tests/tests_data/streaming/test_dataset.py
@@ -55,17 +55,6 @@ def test_streaming_dataset(tmpdir, monkeypatch):
     assert len(dataloader) == 6
 
 
-@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
-@mock.patch("lightning.data.streaming.dataset.os.makedirs")
-@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-def test_create_cache_dir_in_lightning_cloud(makedirs_mock):
-    # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
-    dataset = StreamingDataset("dummy")
-    with pytest.raises(FileNotFoundError, match="/0` doesn't exist"):
-        iter(dataset)
-    makedirs_mock.assert_called()
-
-
 @pytest.mark.parametrize("drop_last", [False, True])
 def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
     seed_everything(42)
@@ -307,7 +296,7 @@ def test_dataset_cache_recreation(tmpdir):
 
 def test_try_create_cache_dir():
     with mock.patch.dict(os.environ, {}, clear=True):
-        assert _try_create_cache_dir("any") is None
+        assert os.path.join("chunks", "100b8cad7cf2a56f6df78f171f97a1ec", "0") in _try_create_cache_dir("any")
 
     # the cache dir creating at /cache requires root privileges, so we need to mock `os.makedirs()`
     with (
@@ -493,9 +482,6 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monk
     monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
     monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)
 
-    def fn(item):
-        return torch.arange(item[0], item[0] + 20).to(torch.int)
-
     functions.optimize(
         optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1
     )
@@ -534,3 +520,10 @@ def fn(item):
 
     for batch_idx, batch in enumerate(dataloader):
         assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Not tested on windows and MacOs")
+def test_s3_streaming_dataset():
+    dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
+    assert dataset.input_dir.url == "s3://pl-flash-data/optimized_tiny_imagenet"
+    assert dataset.input_dir.path is None
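
Outside the Lightning cloud (no LIGHTNING_CLUSTER_ID / LIGHTNING_CLOUD_PROJECT_ID in the environment), _try_create_cache_dir no longer returns None but derives a per-input directory under the default cache root. A sketch of the derivation, reusing the digest asserted by the updated test_try_create_cache_dir above:

import hashlib
import os
from pathlib import Path

digest = hashlib.md5(b"any").hexdigest()  # 100b8cad7cf2a56f6df78f171f97a1ec, per the test
shard_rank = 0
cache_dir = os.path.join(Path.home(), ".lightning", "chunks", digest, str(shard_rank))
# e.g. ~/.lightning/chunks/100b8cad7cf2a56f6df78f171f97a1ec/0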
