diff --git a/src/lightning/data/streaming/constants.py b/src/lightning/data/streaming/constants.py index 759c42e07576d..28f56c286e44e 100644 --- a/src/lightning/data/streaming/constants.py +++ b/src/lightning/data/streaming/constants.py @@ -14,6 +14,7 @@ import os from pathlib import Path +import numpy as np import torch from lightning_utilities.core.imports import RequirementCache @@ -52,4 +53,7 @@ 19: torch.bool, } +_NUMPY_SCTYPES = [v for values in np.sctypes.values() for v in values] +_NUMPY_DTYPES_MAPPING = {i: np.dtype(v) for i, v in enumerate(_NUMPY_SCTYPES)} + _TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ" diff --git a/src/lightning/data/streaming/serializers.py b/src/lightning/data/streaming/serializers.py index 731e7259847cd..0c40a68dc097a 100644 --- a/src/lightning/data/streaming/serializers.py +++ b/src/lightning/data/streaming/serializers.py @@ -22,7 +22,7 @@ import torch from lightning_utilities.core.imports import RequirementCache -from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING +from lightning.data.streaming.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") @@ -200,6 +200,61 @@ def can_serialize(self, item: torch.Tensor) -> bool: return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1 +class NumpySerializer(Serializer): + """The NumpySerializer serialize and deserialize numpy to and from bytes.""" + + def __init__(self) -> None: + super().__init__() + self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()} + + def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]: + dtype_indice = self._dtype_to_indice[item.dtype] + data = [np.uint32(dtype_indice).tobytes()] + data.append(np.uint32(len(item.shape)).tobytes()) + for dim in item.shape: + data.append(np.uint32(dim).tobytes()) + data.append(item.tobytes(order="C")) + return b"".join(data), None + + def deserialize(self, data: bytes) -> np.ndarray: + dtype_indice = np.frombuffer(data[0:4], np.uint32).item() + dtype = _NUMPY_DTYPES_MAPPING[dtype_indice] + shape_size = np.frombuffer(data[4:8], np.uint32).item() + shape = [] + for shape_idx in range(shape_size): + shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) + tensor = np.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + if tensor.shape == shape: + return tensor + return np.reshape(tensor, shape) + + def can_serialize(self, item: np.ndarray) -> bool: + return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) > 1 + + +class NoHeaderNumpySerializer(Serializer): + """The NoHeaderNumpySerializer serialize and deserialize numpy to and from bytes.""" + + def __init__(self) -> None: + super().__init__() + self._dtype_to_indice = {v: k for k, v in _NUMPY_DTYPES_MAPPING.items()} + self._dtype: Optional[np.dtype] = None + + def setup(self, data_format: str) -> None: + self._dtype = _NUMPY_DTYPES_MAPPING[int(data_format.split(":")[1])] + + def serialize(self, item: np.ndarray) -> Tuple[bytes, Optional[str]]: + dtype_indice: int = self._dtype_to_indice[item.dtype] + return item.tobytes(order="C"), f"no_header_numpy:{dtype_indice}" + + def deserialize(self, data: bytes) -> np.ndarray: + assert self._dtype + return np.frombuffer(data, dtype=self._dtype) + + def can_serialize(self, item: np.ndarray) -> bool: + return isinstance(item, np.ndarray) and type(item) == np.ndarray and len(item.shape) == 1 + + class PickleSerializer(Serializer): """The PickleSerializer serialize and deserialize python objects to and from bytes.""" @@ -263,6 +318,8 @@ def can_serialize(self, data: Any) -> bool: "int": IntSerializer(), "jpeg": JPEGSerializer(), "bytes": BytesSerializer(), + "no_header_numpy": NoHeaderNumpySerializer(), + "numpy": NumpySerializer(), "no_header_tensor": NoHeaderTensorSerializer(), "tensor": TensorSerializer(), "pickle": PickleSerializer(), diff --git a/tests/tests_data/streaming/test_serializer.py b/tests/tests_data/streaming/test_serializer.py index 5ec6ac15efc5e..5f0129e6c2cbb 100644 --- a/tests/tests_data/streaming/test_serializer.py +++ b/tests/tests_data/streaming/test_serializer.py @@ -21,11 +21,14 @@ from lightning import seed_everything from lightning.data.streaming.serializers import ( _AV_AVAILABLE, + _NUMPY_DTYPES_MAPPING, _SERIALIZERS, _TORCH_DTYPES_MAPPING, _TORCH_VISION_AVAILABLE, IntSerializer, + NoHeaderNumpySerializer, NoHeaderTensorSerializer, + NumpySerializer, PickleSerializer, PILSerializer, TensorSerializer, @@ -44,6 +47,8 @@ def test_serializers(): "int", "jpeg", "bytes", + "no_header_numpy", + "numpy", "no_header_tensor", "tensor", "pickle", @@ -124,6 +129,25 @@ def test_tensor_serializer(): assert np.mean(ratio_bytes) > 2 +@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows") +def test_numpy_serializer(): + seed_everything(42) + + serializer_tensor = NumpySerializer() + + shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)] + for dtype in _NUMPY_DTYPES_MAPPING.values(): + # Those types aren't supported + if dtype.name in ["object", "bytes", "str", "void"]: + continue + for shape in shapes: + tensor = np.ones(shape, dtype=dtype) + data, _ = serializer_tensor.serialize(tensor) + deserialized_tensor = serializer_tensor.deserialize(data) + assert deserialized_tensor.dtype == dtype + np.testing.assert_equal(tensor, deserialized_tensor) + + def test_assert_bfloat16_tensor_serializer(): serializer = TensorSerializer() tensor = torch.ones((10,), dtype=torch.bfloat16) @@ -143,6 +167,19 @@ def test_assert_no_header_tensor_serializer(): assert torch.equal(t, new_t) +def test_assert_no_header_numpy_serializer(): + serializer = NoHeaderNumpySerializer() + t = np.ones((10,)) + assert serializer.can_serialize(t) + data, name = serializer.serialize(t) + assert name == "no_header_numpy:10" + assert serializer._dtype is None + serializer.setup(name) + assert serializer._dtype == np.dtype("float64") + new_t = serializer.deserialize(data) + np.testing.assert_equal(t, new_t) + + @pytest.mark.skipif( condition=not _TORCH_VISION_AVAILABLE or not _AV_AVAILABLE, reason="Requires: ['torchvision', 'av']" )