Add dataset creation (#18940)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: thomas <[email protected]>
3 people authored Nov 4, 2023
1 parent 8d68607 commit faa64c5
Showing 8 changed files with 112 additions and 21 deletions.
2 changes: 1 addition & 1 deletion requirements/app/app.txt
@@ -1,4 +1,4 @@
-lightning-cloud == 0.5.48 # Must be pinned to ensure compatibility
+lightning-cloud == 0.5.50 # Must be pinned to ensure compatibility
packaging
typing-extensions >=4.4.0, <4.8.0
deepdiff >=5.7.0, <6.6.0
8 changes: 4 additions & 4 deletions src/lightning/data/streaming/cache.py
@@ -19,7 +19,7 @@
from lightning.data.datasets.env import _DistributedEnv
from lightning.data.streaming.constants import (
_INDEX_FILENAME,
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.streaming.item_loader import BaseItemLoader
@@ -29,7 +29,7 @@

logger = logging.Logger(__name__)

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
from lightning_cloud.resolver import _resolve_dir


@@ -67,8 +67,8 @@ def __init__(
if not _TORCH_GREATER_EQUAL_2_1_0:
raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")

-if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
-raise ModuleNotFoundError("Lightning Cloud 0.5.48 or higher is required to use the cache.")
+if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
+raise ModuleNotFoundError("Lightning Cloud 0.5.50 or higher is required to use the cache.")

input_dir = _resolve_dir(input_dir)
self._cache_dir = input_dir.path
20 changes: 19 additions & 1 deletion src/lightning/data/streaming/config.py
@@ -21,7 +21,7 @@
from lightning.data.streaming.sampler import ChunkedIndex

if _TORCH_GREATER_EQUAL_2_1_0:
-from torch.utils._pytree import treespec_loads
+from torch.utils._pytree import tree_unflatten, treespec_loads


class ChunksConfig:
@@ -83,6 +83,24 @@ def data_format(self) -> Any:
raise RuntimeError("The config should be defined.")
return self._config["data_format"]

@property
def data_format_unflattened(self) -> Any:
if self._config is None:
raise RuntimeError("The config should be defined.")
return tree_unflatten(self._config["data_format"], self._config["data_spec"])

@property
def compression(self) -> Any:
if self._config is None:
raise RuntimeError("The config should be defined.")
return self._config["compression"]

@property
def chunk_bytes(self) -> int:
if self._config is None:
raise RuntimeError("The config should be defined.")
return self._config["chunk_bytes"]

@property
def config(self) -> Dict[str, Any]:
if self._config is None:
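The new `data_format_unflattened` property recombines two values the writer stores in `index.json`: the flattened pytree leaves of the per-item format and a serialized `TreeSpec`. A minimal, self-contained sketch of that round-trip, assuming PyTorch >= 2.1 (the sample layout below is illustrative, not taken from this commit):

```python
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_dumps, treespec_loads

# Hypothetical per-item format; the real one is derived from the dataset being chunked.
sample_format = {"image": "jpeg", "label": "int"}

leaves, spec = tree_flatten(sample_format)   # flattened leaves plus a TreeSpec
serialized = treespec_dumps(spec)            # JSON string, as persisted alongside the leaves

# What `data_format_unflattened` effectively does with the stored values:
restored = tree_unflatten(leaves, treespec_loads(serialized))
assert restored == sample_format
```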
2 changes: 1 addition & 1 deletion src/lightning/data/streaming/constants.py
@@ -21,7 +21,7 @@
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48 = RequirementCache("lightning-cloud>=0.5.48")
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50 = RequirementCache("lightning-cloud>=0.5.50")
_BOTO3_AVAILABLE = RequirementCache("boto3")

# DON'T CHANGE ORDER
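These flags are `RequirementCache` objects: each checks the installed distribution against its version specifier and evaluates truthy only when it is satisfied, which is how the 0.5.50 bump above propagates to the imports and runtime checks in the other files. A small sketch of the pattern, assuming `lightning-utilities` exposes `RequirementCache` at this import path; `_require_lightning_cloud` is a hypothetical helper:

```python
from lightning_utilities.core.imports import RequirementCache

# Truthy only if the installed lightning-cloud satisfies the specifier.
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50 = RequirementCache("lightning-cloud>=0.5.50")

if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
    from lightning_cloud.resolver import _resolve_dir


def _require_lightning_cloud() -> None:
    # Raise at call time (e.g. when constructing the cache) rather than at import time,
    # so an outdated optional dependency surfaces as a clear, actionable error.
    if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
        raise ModuleNotFoundError("Lightning Cloud 0.5.50 or higher is required to use the cache.")
```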
70 changes: 61 additions & 9 deletions src/lightning/data/streaming/data_processor.py
@@ -1,10 +1,12 @@
import json
import logging
import os
import signal
import tempfile
import traceback
import types
from abc import abstractmethod
from dataclasses import dataclass
from multiprocessing import Process, Queue
from queue import Empty
from shutil import copyfile, rmtree
@@ -23,7 +25,7 @@
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
-_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
+_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.fabric.accelerators.cuda import is_cuda_available
@@ -35,10 +37,12 @@
from lightning.fabric.utilities.distributed import group as _group

if _TORCH_GREATER_EQUAL_2_1_0:
-from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
from lightning_cloud.openapi import V1DatasetType
from lightning_cloud.resolver import _resolve_dir
from lightning_cloud.utils.dataset import _create_dataset


if _BOTO3_AVAILABLE:
@@ -191,7 +195,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
)
except Exception as e:
print(e)
-if os.path.isdir(output_dir.path):
+elif os.path.isdir(output_dir.path):
copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
else:
raise ValueError(f"The provided {output_dir.path} isn't supported.")
@@ -506,6 +510,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
Process.__init__(self)


@dataclass
class _Result:
size: Optional[int] = None
num_bytes: Optional[str] = None
data_format: Optional[str] = None
compression: Optional[str] = None
num_chunks: Optional[int] = None
num_bytes_per_chunk: Optional[List[int]] = None


T = TypeVar("T")


@@ -545,8 +559,8 @@ def listdir(self, path: str) -> List[str]:
def __init__(self) -> None:
self._name: Optional[str] = None

-def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
-pass
+def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
+return _Result(size=size)


class DataChunkRecipe(DataRecipe):
@@ -576,7 +590,7 @@ def prepare_structure(self, input_dir: Optional[str]) -> List[T]:
def prepare_item(self, item_metadata: T) -> Any: # type: ignore
"""The return of this `prepare_item` method is persisted in chunked binary files."""

-def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
+def _done(self, size: int, delete_cached_files: bool, output_dir: Dir) -> _Result:
num_nodes = _get_num_nodes()
cache_dir = _get_cache_dir()

@@ -589,6 +603,26 @@ def _done(self, delete_cached_files: bool, output_dir: Dir) -> None:
merge_cache._merge_no_wait(node_rank if num_nodes > 1 else None)
self._upload_index(output_dir, cache_dir, num_nodes, node_rank)

if num_nodes == node_rank + 1:
with open(os.path.join(cache_dir, _INDEX_FILENAME)) as f:
config = json.load(f)

size = sum([c["dim"] if c["dim"] is not None else c["chunk_size"] for c in config["chunks"]])
num_bytes = sum([c["chunk_bytes"] for c in config["chunks"]])
data_format = tree_unflatten(config["config"]["data_format"], treespec_loads(config["config"]["data_spec"]))

return _Result(
size=size,
num_bytes=num_bytes,
data_format=data_format,
compression=config["config"]["compression"],
num_chunks=len(config["chunks"]),
num_bytes_per_chunk=[c["chunk_size"] for c in config["chunks"]],
)
return _Result(
size=size,
)

def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_rank: Optional[int]) -> None:
"""This method upload the index file to the remote cloud directory."""
if output_dir.path is None and output_dir.url is None:
@@ -764,13 +798,31 @@ def run(self, data_recipe: DataRecipe) -> None:
has_failed = True
break

num_nodes = _get_num_nodes()
# TODO: Understand why it hangs.
-if _get_num_nodes() == 1:
+if num_nodes == 1:
for w in self.workers:
w.join(0)

print("Workers are finished.")
-data_recipe._done(self.delete_cached_files, self.output_dir)
+result = data_recipe._done(num_items, self.delete_cached_files, self.output_dir)

if num_nodes == _get_node_rank() + 1:
_create_dataset(
input_dir=self.input_dir.path,
storage_dir=self.output_dir.path,
dataset_type=V1DatasetType.CHUNKED
if isinstance(data_recipe, DataChunkRecipe)
else V1DatasetType.TRANSFORMED,
empty=False,
size=result.size,
num_bytes=result.num_bytes,
data_format=result.data_format,
compression=result.compression,
num_chunks=result.num_chunks,
num_bytes_per_chunk=result.num_bytes_per_chunk,
)

print("Finished data processing!")

# TODO: Understand why it is required to avoid long shutdown.
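In summary, each recipe's `_done` now returns a `_Result` aggregated from the chunk metadata in `index.json`, and the last node forwards it to `_create_dataset` so the dataset is registered with its size, byte counts, format, and compression. A hedged sketch of the aggregation step, assuming the `index.json` layout shown above; `summarize_index` is a hypothetical helper, not part of this commit:

```python
import json
import os
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class _Result:
    size: Optional[int] = None
    num_bytes: Optional[int] = None
    data_format: Optional[Any] = None
    compression: Optional[str] = None
    num_chunks: Optional[int] = None
    num_bytes_per_chunk: Optional[List[int]] = None


def summarize_index(cache_dir: str, index_filename: str = "index.json") -> _Result:
    """Aggregate per-chunk metadata roughly the way `DataChunkRecipe._done` does on the last node."""
    with open(os.path.join(cache_dir, index_filename)) as f:
        config = json.load(f)

    chunks = config["chunks"]
    return _Result(
        size=sum(c["dim"] if c["dim"] is not None else c["chunk_size"] for c in chunks),
        num_bytes=sum(c["chunk_bytes"] for c in chunks),
        compression=config["config"]["compression"],
        num_chunks=len(chunks),
        num_bytes_per_chunk=[c["chunk_size"] for c in chunks],
    )
```

The `data_format` field is filled separately via `tree_unflatten`, as sketched in the `config.py` section above.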
4 changes: 2 additions & 2 deletions src/lightning/data/streaming/dataset.py
@@ -19,12 +19,12 @@

from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv
from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48
+from lightning.data.streaming.constants import _INDEX_FILENAME, _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
from lightning_cloud.resolver import _resolve_dir


4 changes: 2 additions & 2 deletions src/lightning/data/streaming/functions.py
@@ -18,10 +18,10 @@
from types import FunctionType
from typing import Any, Callable, Optional, Sequence, Union

-from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48, _TORCH_GREATER_EQUAL_2_1_0
+from lightning.data.streaming.constants import _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe

-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_50:
from lightning_cloud.resolver import _assert_dir_has_index_file, _assert_dir_is_empty, _execute, _resolve_dir

if _TORCH_GREATER_EQUAL_2_1_0:
23 changes: 22 additions & 1 deletion tests/tests_data/streaming/test_data_processor.py
@@ -102,7 +102,7 @@ def copy_file(local_filepath, *args):
called = True
from shutil import copyfile

-copyfile(local_filepath, os.path.join(remote_output_dir.path, os.path.basename(local_filepath)))
+copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath)))

s3_client.client.upload_file = copy_file

@@ -420,8 +420,14 @@ def _broadcast_object(self, obj: Any) -> Any:
def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir, monkeypatch):
"""This test ensures the data optimizer works in a fully distributed settings."""

seed_everything(42)

monkeypatch.setattr(data_processor_module.os, "_exit", mock.MagicMock())

_create_dataset_mock = mock.MagicMock()

monkeypatch.setattr(data_processor_module, "_create_dataset", _create_dataset_mock)

from PIL import Image

input_dir = os.path.join(tmpdir, "dataset")
@@ -501,6 +507,21 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,

assert sorted(os.listdir(remote_output_dir)) == expected

_create_dataset_mock.assert_called()

assert _create_dataset_mock._mock_mock_calls[0].kwargs == {
"input_dir": str(input_dir),
"storage_dir": str(remote_output_dir),
"dataset_type": "CHUNKED",
"empty": False,
"size": 30,
"num_bytes": 26657,
"data_format": "jpeg",
"compression": None,
"num_chunks": 16,
"num_bytes_per_chunk": [2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
}


class TextTokenizeRecipe(DataChunkRecipe):
def prepare_structure(self, input_dir: str) -> List[Any]:
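The new assertions patch `_create_dataset` where `data_processor` looks it up, run the pipeline, and then compare the kwargs of the recorded call against the expected dataset metadata. A stripped-down, standalone sketch of that assertion pattern; `register_dataset` is a hypothetical stand-in for the production code path:

```python
from unittest import mock


def register_dataset(create_dataset, **metadata):
    # Stand-in for the code path that forwards dataset metadata to `_create_dataset`.
    create_dataset(**metadata)


def test_register_dataset_records_expected_kwargs():
    create_dataset = mock.MagicMock()
    register_dataset(create_dataset, dataset_type="CHUNKED", empty=False, num_chunks=16)

    create_dataset.assert_called()
    # `.kwargs` on a recorded call is available on Python 3.8+.
    assert create_dataset.mock_calls[0].kwargs == {
        "dataset_type": "CHUNKED",
        "empty": False,
        "num_chunks": 16,
    }
```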
