Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numpy support for the StreamingDataset 1/2 #19050

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/lightning/data/streaming/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
from pathlib import Path

import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache

Expand Down Expand Up @@ -52,4 +53,7 @@
19: torch.bool,
}

_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values]
_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
59 changes: 58 additions & 1 deletion src/lightning/data/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING
from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
Expand Down Expand Up @@ -200,6 +200,61 @@ def can_serialize(self, item: torch.Tensor) -> bool:
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1


class NumpySerializer(Serializer):
"""The NumpySerializer serialize and deserialize numpy to and from bytes."""

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice = self._dtype_to_indice[item.dtype]
data = [np.uint32(dtype_indice).tobytes()]
data.append(np.uint32(len(item.shape)).tobytes())
for dim in item.shape:
data.append(np.uint32(dim).tobytes())
data.append(item.tobytes(order="C"))
return b"".join(data), None
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def deserialize(self, data: bytes) -> np.ndarray:
dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
shape_size = np.frombuffer(data[4:8], np.uint32).item()
shape = []
for shape_idx in range(shape_size):
shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
Comment on lines +225 to +226
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should leave comments when we hardcode the numbers

if tensor.shape == shape:
return tensor
return np.reshape(tensor, shape)

def can_serialize(self, item: np.ndarray) -> bool:
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) > 1


class NoHeaderNumpySerializer(Serializer):
"""The NoHeaderNumpySerializer serialize and deserialize numpy to and from bytes."""

def __init__(self) -> None:
super().__init__()
self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
self._dtype: Optional[np.dtype] = None

def setup(self, data_format: str) -> None:
self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]

def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
dtype_indice: int = self._dtype_to_indice[item.dtype]
return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"

def deserialize(self, data: bytes) -> np.ndarray:
assert self._dtype
return np.frombuffer(data, dtype=self._dtype)

def can_serialize(self, item: np.ndarray) -> bool:
return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) == 1


class PickleSerializer(Serializer):
"""The PickleSerializer serialize and deserialize python objects to and from bytes."""

Expand Down Expand Up @@ -263,6 +318,8 @@ def can_serialize(self, data: Any) -> bool:
"int": IntSerializer(),
"jpeg": JPEGSerializer(),
"bytes": BytesSerializer(),
"no_header_numpy": NoHeaderNumpySerializer(),
"numpy": NumpySerializer(),
"no_header_tensor": NoHeaderTensorSerializer(),
"tensor": TensorSerializer(),
"pickle": PickleSerializer(),
Expand Down
37 changes: 37 additions & 0 deletions tests/tests_data/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
from lightning import seed_everything
from lightning.data.streaming.serializers import (
_AV_AVAILABLE,
_NUMPY_DTYPES_MAPPING,
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
_TORCH_VISION_AVAILABLE,
IntSerializer,
NoHeaderNumpySerializer,
NoHeaderTensorSerializer,
NumpySerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
Expand All @@ -44,6 +47,8 @@ def test_serializers():
"int",
"jpeg",
"bytes",
"no_header_numpy",
"numpy",
"no_header_tensor",
"tensor",
"pickle",
Expand Down Expand Up @@ -124,6 +129,25 @@ def test_tensor_serializer():
assert np.mean(ratio_bytes) > 2


@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_numpy_serializer():
seed_everything(42)

serializer_tensor = NumpySerializer()

shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
for dtype in _NUMPY_DTYPES_MAPPING.values():
# Those types aren't supported
if dtype.name in ["object", "bytes", "str", "void"]:
continue
tchaton marked this conversation as resolved.
Show resolved Hide resolved
for shape in shapes:
tensor = np.ones(shape, dtype=dtype)
data, _ = serializer_tensor.serialize(tensor)
deserialized_tensor = serializer_tensor.deserialize(data)
assert deserialized_tensor.dtype == dtype
np.testing.assert_equal(tensor, deserialized_tensor)


def test_assert_bfloat16_tensor_serializer():
serializer = TensorSerializer()
tensor = torch.ones((10,), dtype=torch.bfloat16)
Expand All @@ -143,6 +167,19 @@ def test_assert_no_header_tensor_serializer():
assert torch.equal(t, new_t)


def test_assert_no_header_numpy_serializer():
serializer = NoHeaderNumpySerializer()
t = np.ones((10,))
assert serializer.can_serialize(t)
data, name = serializer.serialize(t)
assert name == "no_header_numpy:10"
assert serializer._dtype is None
serializer.setup(name)
assert serializer._dtype == np.dtype("float64")
new_t = serializer.deserialize(data)
np.testing.assert_equal(t, new_t)


@pytest.mark.skipif(
condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']"
)
Expand Down