Add direct s3 support to the streaming dataset #19044

Merged · 8 commits · Nov 22, 2023

Changes from 6 commits
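This PR lets `StreamingDataset` read optimised chunks straight from S3, either from a plain `s3://` URL or via the new `RemoteDir` dataclass that pairs the remote URL with an explicit local cache directory. A minimal usage sketch based on the tests added below (the bucket path is taken from the new test; treat it as illustrative):

```python
from lightning.data.streaming.dataset import RemoteDir, StreamingDataset

# Plain s3:// URL: chunks are cached under a per-dataset default directory.
dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")

# RemoteDir: pin the cache location explicitly.
dataset = StreamingDataset(
    input_dir=RemoteDir(cache_dir="/tmp/chunks", remote="s3://pl-flash-data/optimized_tiny_imagenet")
)
```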
5 changes: 5 additions & 0 deletions .github/workflows/ci-tests-data.yml
@@ -94,7 +94,12 @@ jobs:

- name: Testing Data
working-directory: tests/tests_data
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_PUB_ACCESS_KEY }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_PUB_SECRET_KEY }}
AWS_DEFAULT_REGION: ${{ secrets.AWS_REGION }}
# NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
timeout-minutes: 10
run: |
python -m coverage run --source lightning \
-m pytest -v --timeout=60 --durations=60
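The new `env` block passes AWS credentials from repository secrets into the data-tests job. Presumably the S3 reads inside the tests pick these up through the standard AWS credential chain rather than through explicit arguments; roughly (a sketch, assuming a botocore-based client):

```python
import boto3

# botocore's default credential chain reads AWS_ACCESS_KEY_ID,
# AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION from the environment,
# so the test code never has to handle keys directly.
s3 = boto3.client("s3")
s3.list_objects_v2(Bucket="pl-flash-data", Prefix="optimized_tiny_imagenet", MaxKeys=1)
```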
4 changes: 4 additions & 0 deletions src/lightning/data/streaming/constants.py
@@ -11,12 +11,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import torch
from lightning_utilities.core.imports import RequirementCache

_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MB
_DEFAULT_FAST_DEV_RUN_ITEMS = 10
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")

# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
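The new `_DEFAULT_CACHE_DIR` anchors the local cache under the user's home directory; a quick check of what the constant evaluates to:

```python
import os
from pathlib import Path

# Mirrors the constant added above.
print(os.path.join(Path.home(), ".lightning", "chunks"))
# e.g. /home/alice/.lightning/chunks
```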
31 changes: 23 additions & 8 deletions src/lightning/data/streaming/dataset.py
@@ -13,29 +13,30 @@

import hashlib
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv

if _LIGHTNING_CLOUD_LATEST:
-from lightning_cloud.resolver import _resolve_dir
+from lightning_cloud.resolver import Dir, _resolve_dir


class StreamingDataset(IterableDataset):
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class."""

def __init__(
self,
-        input_dir: str,
+        input_dir: Union[str, "RemoteDir"],
item_loader: Optional[BaseItemLoader] = None,
shuffle: bool = False,
drop_last: bool = False,
@@ -58,6 +59,9 @@ def __init__(
if not isinstance(shuffle, bool):
raise ValueError(f"Shuffle should be a boolean. Found {shuffle}")

if isinstance(input_dir, RemoteDir):
input_dir = Dir(path=input_dir.cache_dir, url=input_dir.remote)

input_dir = _resolve_dir(input_dir)

self.input_dir = input_dir
@@ -84,9 +88,10 @@ def __init__(
def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)

-        # TODO: Move this to lightning-cloud
-        if "this_" not in self.input_dir.path:
-            cache_path = _try_create_cache_dir(input_dir=self.input_dir.path, shard_rank=env.shard_rank)
+        if self.input_dir.path is None:
+            cache_path = _try_create_cache_dir(
+                input_dir=self.input_dir.path if self.input_dir.path else self.input_dir.url, shard_rank=env.shard_rank
+            )
if cache_path is not None:
self.input_dir.path = cache_path

@@ -194,9 +199,19 @@ def __next__(self) -> Any:


def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
return None
hash_object = hashlib.md5(input_dir.encode())
if "LIGHTNING_CLUSTER_ID" not in os.environ or "LIGHTNING_CLOUD_PROJECT_ID" not in os.environ:
cache_dir = os.path.join(_DEFAULT_CACHE_DIR, hash_object.hexdigest(), str(shard_rank))
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
cache_dir = os.path.join("/cache", "chunks", hash_object.hexdigest(), str(shard_rank))
os.makedirs(cache_dir, exist_ok=True)
return cache_dir


@dataclass
class RemoteDir:
"""Holds a remote URL to a directory and a cache directory where the data will be downloaded."""

cache_dir: str
remote: str
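Taken together, the reworked `_try_create_cache_dir` now always returns a directory: the input URL (or path) is md5-hashed so distinct datasets never share a cache, and the root switches from the local default to `/cache/chunks` when both Lightning cloud environment variables are set. A standalone sketch of the same resolution logic (the function name here is illustrative):

```python
import hashlib
import os
from pathlib import Path

_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")

def resolve_cache_dir(input_dir: str, shard_rank: int = 0) -> str:
    # The remote URL (or path) is hashed so distinct datasets never collide.
    digest = hashlib.md5(input_dir.encode()).hexdigest()
    on_cloud = "LIGHTNING_CLUSTER_ID" in os.environ and "LIGHTNING_CLOUD_PROJECT_ID" in os.environ
    root = os.path.join("/cache", "chunks") if on_cloud else _DEFAULT_CACHE_DIR
    cache_dir = os.path.join(root, digest, str(shard_rank))
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

# resolve_cache_dir("s3://pl-flash-data/optimized_tiny_imagenet")
# -> ~/.lightning/chunks/<md5>/0 locally, /cache/chunks/<md5>/0 on the cloud
```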
37 changes: 21 additions & 16 deletions tests/tests_data/streaming/test_dataset.py
@@ -20,7 +20,7 @@
import torch
from lightning import seed_everything
from lightning.data.streaming import Cache, functions
-from lightning.data.streaming.dataset import StreamingDataset, _try_create_cache_dir
+from lightning.data.streaming.dataset import RemoteDir, StreamingDataset, _try_create_cache_dir
from lightning.data.streaming.item_loader import TokensLoader
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle
from lightning.data.utilities.env import _DistributedEnv
@@ -55,17 +55,6 @@ def test_streaming_dataset(tmpdir, monkeypatch):
assert len(dataloader) == 6


-@mock.patch.dict(os.environ, {"LIGHTNING_CLUSTER_ID": "123", "LIGHTNING_CLOUD_PROJECT_ID": "456"})
-@mock.patch("lightning.data.streaming.dataset.os.makedirs")
-@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
-def test_create_cache_dir_in_lightning_cloud(makedirs_mock):
-    # Locally, we can't actually write to the root filesystem with user privileges, so we need to mock the call
-    dataset = StreamingDataset("dummy")
-    with pytest.raises(FileNotFoundError, match="/0` doesn't exist"):
-        iter(dataset)
-    makedirs_mock.assert_called()


@pytest.mark.parametrize("drop_last", [False, True])
def test_streaming_dataset_distributed_no_shuffle(drop_last, tmpdir):
seed_everything(42)
@@ -307,7 +296,7 @@ def test_dataset_cache_recreation(tmpdir):

def test_try_create_cache_dir():
    with mock.patch.dict(os.environ, {}, clear=True):
-        assert _try_create_cache_dir("any") is None
+        assert os.path.join("chunks", "100b8cad7cf2a56f6df78f171f97a1ec", "0") in _try_create_cache_dir("any")

# creating the cache dir at /cache requires root privileges, so we need to mock `os.makedirs()`
with (
@@ -493,9 +482,6 @@ def test_dataset_for_text_tokens_distributed_num_workers_end_to_end(tmpdir, monkeypatch):
monkeypatch.setenv("DATA_OPTIMIZER_CACHE_FOLDER", cache_dir)
monkeypatch.setenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", cache_dir)

-    def fn(item):
-        return torch.arange(item[0], item[0] + 20).to(torch.int)

functions.optimize(
optimize_fn, inputs, output_dir=str(tmpdir), num_workers=2, chunk_size=2, reorder_files=False, num_downloaders=1
)
@@ -534,3 +520,22 @@ def fn(item):

for batch_idx, batch in enumerate(dataloader):
assert [batch[0][0].item(), batch[1][0].item()] == expected[batch_idx]


@pytest.mark.skipif(sys.platform == "win32" or sys.platform == "darwin", reason="Not tested on Windows and macOS")
def test_s3_streaming_dataset(tmpdir):
dataset = StreamingDataset(input_dir="s3://pl-flash-data/optimized_tiny_imagenet")
assert len(dataset) == 1000
expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8)
generated = dataset[0][0][0][:5]
assert torch.equal(generated, expected)

dataset = StreamingDataset(
input_dir=RemoteDir(cache_dir=str(tmpdir), remote="s3://pl-flash-data/optimized_tiny_imagenet")
)
assert len(dataset) == 1000
expected = torch.CharTensor([40, 41, 8, 29, 67]).to(dtype=torch.uint8)
generated = dataset[0][0][0][:5]
assert torch.equal(generated, expected)

assert sorted(os.listdir(tmpdir)) == ["chunk-0-0.bin", "index.json"]
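
For completeness, the streamed dataset plugs into a regular `DataLoader` like any other `IterableDataset`. A hedged end-to-end sketch reusing the test bucket (the cache path is illustrative):

```python
from torch.utils.data import DataLoader

from lightning.data.streaming.dataset import RemoteDir, StreamingDataset

dataset = StreamingDataset(
    input_dir=RemoteDir(cache_dir="/tmp/imagenet_chunks", remote="s3://pl-flash-data/optimized_tiny_imagenet")
)
dataloader = DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # Chunks are downloaded into cache_dir on first access, then reused.
    break
```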