diff --git a/pydra/utils/hash.py b/pydra/utils/hash.py index 3ba3e97b4..816a0e1a2 100644 --- a/pydra/utils/hash.py +++ b/pydra/utils/hash.py @@ -43,6 +43,27 @@ else: HAVE_NUMPY = True +try: + import pandas +except ImportError: + HAVE_PANDAS = False +else: + HAVE_PANDAS = True + +try: + import torch +except ImportError: + HAVE_PYTORCH = False +else: + HAVE_PYTORCH = True + +try: + import tensorflow +except ImportError: + HAVE_TENSORFLOW = False +else: + HAVE_TENSORFLOW = True + __all__ = ( "hash_function", "hash_object", @@ -565,4 +586,28 @@ def bytes_repr_numpy(obj: numpy.ndarray, cache: Cache) -> Iterator[bytes]: yield obj.tobytes(order="C") +if HAVE_PYTORCH: + + @register_serializer(torch.Tensor) + def bytes_repr_torch(obj: torch.Tensor, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.numpy(), cache) + + +if HAVE_TENSORFLOW: + + @register_serializer(tensorflow.Tensor) + def bytes_repr_tensorflow(obj: tensorflow.Tensor, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.numpy(), cache) + + +if HAVE_PANDAS: + + @register_serializer(pandas.DataFrame) + def bytes_repr_pandas(obj: pandas.DataFrame, cache: Cache) -> Iterator[bytes]: + yield f"{obj.__class__.__module__}{obj.__class__.__name__}:".encode() + yield from bytes_repr_numpy(obj.to_numpy(), cache) + + NUMPY_CHUNK_LEN = 8192