From 233470d81410160baaa85901c5e1e3c2caa3a060 Mon Sep 17 00:00:00 2001 From: Jordan Wilke Date: Wed, 17 Jul 2024 16:32:15 -0700 Subject: [PATCH 1/2] Registered serializer for common classes of additional array-like objects --- pydra/utils/hash.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index 3ba3e97b4..a222b5acc 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -43,6 +43,27 @@ else: HAVE_NUMPY = True +try: + import pandas +except ImportError: + HAVE_PANDAS = False +else: + HAVE_PANDAS = True + +try: + import torch +except ImportError: + HAVE_PYTORCH = False +else: + HAVE_PYTORCH = True + +try: + import tensorflow +except ImportError: + HAVE_TENSORFLOW = False +else: + HAVE_TENSORFLOW = True + __all__ = ( "hash_function", "hash_object", @@ -564,5 +585,27 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]: else: yield obj.tobytes(order="C") +if HAVE_PYTORCH: + + @register_serializer(torch.Tensor) + def bytes_repr_torch(obj: torch.Tensor, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.numpy(), cache) + + +if HAVE_TENSORFLOW: + + @register_serializer(tensorflow.Tensor) + def bytes_repr_tensorflow(obj: tensorflow.Tensor, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.numpy(), cache) + + +if HAVE_PANDAS: + + @register_serializer(pandas.DataFrame) + def bytes_repr_pandas(obj: pandas.DataFrame, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.to_numpy(), cache) NUMPY_CHUNK_LEN = 8192 From 83c721b1fd8e57046d50bab745ce77e382a4dcfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jul 2024 02:40:30 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pydra/utils/hash.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index a222b5acc..816a0e1a2 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -585,6 +585,7 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]: else: yield obj.tobytes(order="C") + if HAVE_PYTORCH: @register_serializer(torch.Tensor) @@ -608,4 +609,5 @@ def bytes_repr_pandas(obj: pandas.DataFrame, cache: Cache) -> Iterator[bytes]: yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() yield from bytes_repr_numpy(obj.to_numpy(), cache) + NUMPY_CHUNK_LEN = 8192