Skip to content

Commit

Permalink
Add numpy support for the StreamingDataset 1/2 (#19050)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 7eca9c1)
  • Loading branch information
tchaton authored and Borda committed Dec 19, 2023
1 parent ed7746f commit 6e2dc5e
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/lightning/data/streaming/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
from pathlib import Path

import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache

Expand Down Expand Up @@ -52,4 +53,7 @@
19: torch.bool,
}

# Flat, ordered list of NumPy scalar types. The position of a type in this list is the
# integer id stored in ``_NUMPY_DTYPES_MAPPING``, so the ordering must remain stable.
_NP_SCTYPES_BY_KIND = getattr(np, "sctypes", None)
if _NP_SCTYPES_BY_KIND is not None:
    _NUMPY_SCTYPES = [v for values in _NP_SCTYPES_BY_KIND.values() for v in values]
else:
    # ``np.sctypes`` is not available (it was removed in NumPy 2.0). Rebuild the list in
    # the grouping order it used: ints, unsigned ints, floats, complex, then "others".
    # Extended-precision names only exist on some platforms, hence the ``hasattr`` filter.
    _NP_SIZED_SCALAR_NAMES = (
        "int8",
        "int16",
        "int32",
        "int64",
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "float16",
        "float32",
        "float64",
        "float96",
        "float128",
        "complex64",
        "complex128",
        "complex192",
        "complex256",
    )
    _NUMPY_SCTYPES = [getattr(np, name) for name in _NP_SIZED_SCALAR_NAMES if hasattr(np, name)]
    _NUMPY_SCTYPES += [bool, object, bytes, str, np.void]

# Integer id -> ``np.dtype``; the id is what the numpy serializers write and read back.
_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
59 changes: 58 additions & 1 deletion src/lightning/data/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
from lightning_utilities.core.imports import RequirementCache

from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING
from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING

_PIL_AVAILABLE = RequirementCache("PIL")
_TORCH_VISION_AVAILABLE = RequirementCache("torchvision")
Expand Down Expand Up @@ -200,6 +200,61 @@ def can_serialize(self, item: torch.Tensor) -> bool:
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1


class NumpySerializer(Serializer):
    """Serialize and deserialize numpy arrays to and from bytes using a small header.

    Layout written by :meth:`serialize` (each header field is a ``np.uint32`` as produced
    by ``tobytes()``):

    * the dtype index (a key of ``_NUMPY_DTYPES_MAPPING``)
    * the number of dimensions
    * one field per dimension
    * the array content in C order

    """

    def __init__(self) -> None:
        super().__init__()
        # Reverse lookup: numpy dtype -> integer index stored in the header.
        self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}

    def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
        """Return the header followed by the raw array bytes, and ``None`` as the format name.

        Raises:
            KeyError: If ``item.dtype`` is not one of the dtypes in ``_NUMPY_DTYPES_MAPPING``.

        """
        dtype_indice = self._dtype_to_indice[item.dtype]
        data = [np.uint32(dtype_indice).tobytes()]
        data.append(np.uint32(len(item.shape)).tobytes())
        data.extend(np.uint32(dim).tobytes() for dim in item.shape)
        data.append(item.tobytes(order="C"))
        return b"".join(data), None

    def deserialize(self, data: bytes) -> np.ndarray:
        """Rebuild the array described by the header at the start of ``data``."""
        dtype_indice = np.frombuffer(data[0:4], np.uint32).item()
        dtype = _NUMPY_DTYPES_MAPPING[dtype_indice]
        ndim = np.frombuffer(data[4:8], np.uint32).item()
        # Derive the payload offset from the dimension count instead of a loop variable,
        # so a header declaring zero dimensions does not reference an unbound name.
        payload_start = 8 + 4 * ndim
        # A tuple, so the comparison with ``ndarray.shape`` below can actually succeed.
        shape = tuple(
            np.frombuffer(data[8 + 4 * idx : 8 + 4 * (idx + 1)], np.uint32).item() for idx in range(ndim)
        )
        array = np.frombuffer(data[payload_start:], dtype=dtype)
        if array.shape == shape:
            return array
        return np.reshape(array, shape)

    def can_serialize(self, item: np.ndarray) -> bool:
        """Return True only for plain ``np.ndarray`` instances with more than one dimension."""
        # The exact type check excludes ndarray subclasses.
        return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) > 1


class NoHeaderNumpySerializer(Serializer):
    """Serialize and deserialize 1-dimensional numpy arrays to and from bytes without a header.

    The dtype is not stored in the payload. :meth:`serialize` returns it as part of the
    format name (``"no_header_numpy:<dtype index>"``), and :meth:`setup` must be called
    with that name before :meth:`deserialize`.

    """

    def __init__(self) -> None:
        super().__init__()
        # Reverse lookup: numpy dtype -> integer index embedded in the format name.
        self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()}
        # Set by ``setup``; stays ``None`` until then.
        self._dtype: Optional[np.dtype] = None

    def setup(self, data_format: str) -> None:
        """Select the dtype used by ``deserialize`` from the index after ``:`` in ``data_format``."""
        self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])]

    def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]:
        """Return the raw C-ordered bytes of ``item`` and the format name carrying its dtype index.

        Raises:
            KeyError: If ``item.dtype`` is not one of the dtypes in ``_NUMPY_DTYPES_MAPPING``.

        """
        dtype_indice: int = self._dtype_to_indice[item.dtype]
        return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}"

    def deserialize(self, data: bytes) -> np.ndarray:
        """Interpret ``data`` as a flat array of the dtype selected by ``setup``.

        Raises:
            AssertionError: If ``setup`` has not been called yet.

        """
        # An explicit check instead of ``assert``: asserts are stripped under ``python -O``,
        # and passing ``dtype=None`` on to NumPy would decode the bytes with its default
        # dtype rather than fail. AssertionError is kept so the exception type is unchanged.
        if self._dtype is None:
            raise AssertionError("The dtype is not defined. Call `setup` with the data format before `deserialize`.")
        return np.frombuffer(data, dtype=self._dtype)

    def can_serialize(self, item: np.ndarray) -> bool:
        """Return True only for plain ``np.ndarray`` instances with exactly one dimension."""
        # The exact type check excludes ndarray subclasses.
        return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) == 1


class PickleSerializer(Serializer):
"""The PickleSerializer serialize and deserialize python objects to and from bytes."""

Expand Down Expand Up @@ -263,6 +318,8 @@ def can_serialize(self, data: Any) -> bool:
"int": IntSerializer(),
"jpeg": JPEGSerializer(),
"bytes": BytesSerializer(),
"no_header_numpy": NoHeaderNumpySerializer(),
"numpy": NumpySerializer(),
"no_header_tensor": NoHeaderTensorSerializer(),
"tensor": TensorSerializer(),
"pickle": PickleSerializer(),
Expand Down
37 changes: 37 additions & 0 deletions tests/tests_data/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
from lightning import seed_everything
from lightning.data.streaming.serializers import (
_AV_AVAILABLE,
_NUMPY_DTYPES_MAPPING,
_SERIALIZERS,
_TORCH_DTYPES_MAPPING,
_TORCH_VISION_AVAILABLE,
IntSerializer,
NoHeaderNumpySerializer,
NoHeaderTensorSerializer,
NumpySerializer,
PickleSerializer,
PILSerializer,
TensorSerializer,
Expand All @@ -44,6 +47,8 @@ def test_serializers():
"int",
"jpeg",
"bytes",
"no_header_numpy",
"numpy",
"no_header_tensor",
"tensor",
"pickle",
Expand Down Expand Up @@ -124,6 +129,25 @@ def test_tensor_serializer():
assert np.mean(ratio_bytes) > 2


@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows")
def test_numpy_serializer():
    """Round-trip arrays of ones through NumpySerializer for each supported dtype and shape."""
    seed_everything(42)

    serializer = NumpySerializer()

    # Those types aren't supported
    unsupported_names = ["object", "bytes", "str", "void"]
    candidate_shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)]
    supported_dtypes = [dt for dt in _NUMPY_DTYPES_MAPPING.values() if dt.name not in unsupported_names]

    for dt in supported_dtypes:
        for dims in candidate_shapes:
            expected = np.ones(dims, dtype=dt)
            payload = serializer.serialize(expected)[0]
            restored = serializer.deserialize(payload)
            # Both the dtype and the content must survive the round trip.
            assert restored.dtype == dt
            np.testing.assert_equal(expected, restored)


def test_assert_bfloat16_tensor_serializer():
serializer = TensorSerializer()
tensor = torch.ones((10,), dtype=torch.bfloat16)
Expand All @@ -143,6 +167,19 @@ def test_assert_no_header_tensor_serializer():
assert torch.equal(t, new_t)


def test_assert_no_header_numpy_serializer():
    """Check the format name, the effect of ``setup`` on the dtype, and the round trip of a 1-D array."""
    serializer = NoHeaderNumpySerializer()
    array = np.ones((10,))
    assert serializer.can_serialize(array)

    payload, data_format = serializer.serialize(array)
    assert data_format == "no_header_numpy:10"

    # The dtype is only known once ``setup`` has parsed the format name.
    assert serializer._dtype is None
    serializer.setup(data_format)
    assert serializer._dtype == np.dtype("float64")

    restored = serializer.deserialize(payload)
    np.testing.assert_equal(array, restored)


@pytest.mark.skipif(
condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']"
)
Expand Down

0 comments on commit 6e2dc5e

Please sign in to comment.