From 069480b351cb62c991dc3602a5ae0838891f104d Mon Sep 17 00:00:00 2001 From: Siyuan Date: Tue, 24 Nov 2020 01:25:29 -0800 Subject: [PATCH 1/6] zero-copy serializer for pytorch --- python/ray/serialization.py | 2 ++ python/ray/serialization_addons.py | 53 ++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 python/ray/serialization_addons.py diff --git a/python/ray/serialization.py b/python/ray/serialization.py index f85b07afa5bee..ef1ec50f271c6 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -26,6 +26,7 @@ MessagePackSerializedObject, RawSerializedObject, ) +from ray import serialization_addons logger = logging.getLogger(__name__) @@ -155,6 +156,7 @@ def object_ref_reducer(obj): # Because objects have default __reduce__ method, we only need to # treat ObjectRef specifically. self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer) + serialization_addons.apply(self) def _register_cloudpickle_reducer(self, cls, reducer): pickle.CloudPickler.dispatch[cls] = reducer diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py new file mode 100644 index 0000000000000..c5856d3aa350c --- /dev/null +++ b/python/ray/serialization_addons.py @@ -0,0 +1,53 @@ +""" +This module is intended for implementing internal serializers for some +site packages. +""" + +import warnings + +try: + import torch + + + class _TorchTensorReducingHelper: + def __init__(self, tensor): + self.tensor = tensor + + @classmethod + def rebuild_tensor(cls, _rebuild_func, ndarray, params): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, + message="The given NumPy array is not writeable") + storage = torch.from_numpy(ndarray).storage() + tensor = _rebuild_func(storage, *params) + return cls(tensor) + + def __reduce_ex__(self, protocol): + if self.tensor.is_sparse: + # Torch will help us reduce the sparse tensor into + # several continuous tensors. + return _TorchTensorReducingHelper, (self.tensor,) + # By only replacing the storage with a numpy array, we can reuse + # zero-copy serialization while keeping all other params of the torch tensor. + _rebuild_func, content = self.tensor.__reduce_ex__(protocol) + return self.rebuild_tensor, (_rebuild_func, self.tensor.numpy(), content[1:]) + + + def _unwrap_tensor(s): + return s.tensor + + + def torch_tensor_reducer(tensor): + return _unwrap_tensor, (_TorchTensorReducingHelper(tensor),) + +except ImportError: + pass + + +def apply(serialization_context): + try: + import torch + serialization_context._register_cloudpickle_reducer( + torch.Tensor, torch_tensor_reducer) + except ImportError: + pass From b3cb60c16e84adb6f3bffbe9e89260ea70378d65 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Tue, 24 Nov 2020 01:51:07 -0800 Subject: [PATCH 2/6] address possible bottleneck --- python/ray/serialization_addons.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py index c5856d3aa350c..ac9a6a1c7472d 100644 --- a/python/ray/serialization_addons.py +++ b/python/ray/serialization_addons.py @@ -8,6 +8,8 @@ try: import torch + _TORCH_WARNING_FILTER_ACTIVATE = True + class _TorchTensorReducingHelper: def __init__(self, tensor): @@ -15,9 +17,17 @@ def __init__(self, tensor): @classmethod def rebuild_tensor(cls, _rebuild_func, ndarray, params): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="The given NumPy array is not writeable") + global _TORCH_WARNING_FILTER_ACTIVATE + # filtering warning messages would be the bottleneck for + # deserializing torch tensors. Since the warning only prompts once, + # we would only deal with it for the first time. + if _TORCH_WARNING_FILTER_ACTIVATE: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, + message="The given NumPy array is not writeable") + storage = torch.from_numpy(ndarray).storage() + _TORCH_WARNING_FILTER_ACTIVATE = False + else: storage = torch.from_numpy(ndarray).storage() tensor = _rebuild_func(storage, *params) return cls(tensor) From 19098cc757288e79daebcbed8bfc9a268bcd2c89 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Tue, 24 Nov 2020 11:42:22 -0800 Subject: [PATCH 3/6] lint --- python/ray/serialization_addons.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py index ac9a6a1c7472d..11cd1ae84ca27 100644 --- a/python/ray/serialization_addons.py +++ b/python/ray/serialization_addons.py @@ -10,7 +10,6 @@ _TORCH_WARNING_FILTER_ACTIVATE = True - class _TorchTensorReducingHelper: def __init__(self, tensor): self.tensor = tensor @@ -23,8 +22,10 @@ def rebuild_tensor(cls, _rebuild_func, ndarray, params): # we would only deal with it for the first time. if _TORCH_WARNING_FILTER_ACTIVATE: with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, - message="The given NumPy array is not writeable") + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The given NumPy array is not writeable") storage = torch.from_numpy(ndarray).storage() _TORCH_WARNING_FILTER_ACTIVATE = False else: @@ -36,19 +37,19 @@ def __reduce_ex__(self, protocol): if self.tensor.is_sparse: # Torch will help us reduce the sparse tensor into # several continuous tensors. - return _TorchTensorReducingHelper, (self.tensor,) + return _TorchTensorReducingHelper, (self.tensor, ) # By only replacing the storage with a numpy array, we can reuse - # zero-copy serialization while keeping all other params of the torch tensor. + # zero-copy serialization while keeping all other params of the + # torch tensor. _rebuild_func, content = self.tensor.__reduce_ex__(protocol) - return self.rebuild_tensor, (_rebuild_func, self.tensor.numpy(), content[1:]) - + return self.rebuild_tensor, (_rebuild_func, self.tensor.numpy(), + content[1:]) def _unwrap_tensor(s): return s.tensor - def torch_tensor_reducer(tensor): - return _unwrap_tensor, (_TorchTensorReducingHelper(tensor),) + return _unwrap_tensor, (_TorchTensorReducingHelper(tensor), ) except ImportError: pass From 731d9438911f44fc1e0c5e61c3a51b3bb7a0abea Mon Sep 17 00:00:00 2001 From: Siyuan Date: Tue, 24 Nov 2020 19:51:47 -0800 Subject: [PATCH 4/6] add tests --- python/ray/serialization_addons.py | 13 +++++++++---- python/ray/tests/test_serialization.py | 18 +++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py index 11cd1ae84ca27..441b0fe435ac4 100644 --- a/python/ray/serialization_addons.py +++ b/python/ray/serialization_addons.py @@ -15,7 +15,7 @@ def __init__(self, tensor): self.tensor = tensor @classmethod - def rebuild_tensor(cls, _rebuild_func, ndarray, params): + def rebuild_tensor(cls, rebuild_func, ndarray, params): global _TORCH_WARNING_FILTER_ACTIVATE # filtering warning messages would be the bottleneck for # deserializing torch tensors. Since the warning only prompts once, @@ -30,18 +30,23 @@ def rebuild_tensor(cls, _rebuild_func, ndarray, params): _TORCH_WARNING_FILTER_ACTIVATE = False else: storage = torch.from_numpy(ndarray).storage() - tensor = _rebuild_func(storage, *params) + tensor = rebuild_func(storage, *params) + return cls(tensor) + + @classmethod + def rebuild_sparse_tensor(cls, rebuild_func, content): + tensor = rebuild_func(*content) return cls(tensor) def __reduce_ex__(self, protocol): + _rebuild_func, content = self.tensor.__reduce_ex__(protocol) if self.tensor.is_sparse: # Torch will help us reduce the sparse tensor into # several continuous tensors. - return _TorchTensorReducingHelper, (self.tensor, ) + return self.rebuild_sparse_tensor, (_rebuild_func, content) # By only replacing the storage with a numpy array, we can reuse # zero-copy serialization while keeping all other params of the # torch tensor. - _rebuild_func, content = self.tensor.__reduce_ex__(protocol) return self.rebuild_tensor, (_rebuild_func, self.tensor.numpy(), content[1:]) diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index 35b6e09fedbb3..a2837e1da13ed 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -543,7 +543,7 @@ def __del__(self): assert new_obj() is None -def test_buffer_alignment(): +def test_buffer_alignment(ray_start_shared_local_modes): # Deserialized large numpy arrays should be 64-byte aligned. x = np.random.normal(size=(10, 20, 30)) y = ray.get(ray.put(x)) @@ -568,6 +568,22 @@ def test_buffer_alignment(): assert y.ctypes.data % 8 == 0 +def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes): + import torch + tensor = torch.rand(32, 3, 64, 64) + ref = ray.put(tensor) + tensor_1, tensor_2 = ray.get([ref] * 2) + assert tensor_1.data_ptr() == tensor_2.data_ptr() + + i = torch.arange(0, 1024 * 1024, 4).view(1, -1) + v = torch.rand(1024 * 1024 // 4) + k = torch.sparse_coo_tensor(i, v, size=(1024 * 1024, )) + ref = ray.put(k) + tensor_1, tensor_2 = ray.get([ref] * 2) + assert tensor_1._indices().data_ptr() == tensor_2._indices().data_ptr() + assert tensor_1._values().data_ptr() == tensor_2._values().data_ptr() + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) From 40d0fb552ed379870b34613455e9b17cd9aad8b2 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 25 Nov 2020 14:49:23 -0800 Subject: [PATCH 5/6] more tests & device support --- python/ray/serialization_addons.py | 13 ++++++++----- python/ray/tests/test_serialization.py | 8 ++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py index 441b0fe435ac4..7ab085a41f85a 100644 --- a/python/ray/serialization_addons.py +++ b/python/ray/serialization_addons.py @@ -15,7 +15,7 @@ def __init__(self, tensor): self.tensor = tensor @classmethod - def rebuild_tensor(cls, rebuild_func, ndarray, params): + def rebuild_tensor(cls, rebuild_func, device, ndarray, params): global _TORCH_WARNING_FILTER_ACTIVATE # filtering warning messages would be the bottleneck for # deserializing torch tensors. Since the warning only prompts once, @@ -26,11 +26,13 @@ def rebuild_tensor(cls, rebuild_func, ndarray, params): "ignore", category=UserWarning, message="The given NumPy array is not writeable") - storage = torch.from_numpy(ndarray).storage() + _tensor = torch.from_numpy(ndarray) _TORCH_WARNING_FILTER_ACTIVATE = False else: - storage = torch.from_numpy(ndarray).storage() - tensor = rebuild_func(storage, *params) + _tensor = torch.from_numpy(ndarray) + if device != torch.device('cpu'): + _tensor = _tensor.to(device) + tensor = rebuild_func(_tensor.storage(), *params) return cls(tensor) @classmethod @@ -47,7 +49,8 @@ def __reduce_ex__(self, protocol): # By only replacing the storage with a numpy array, we can reuse # zero-copy serialization while keeping all other params of the # torch tensor. - return self.rebuild_tensor, (_rebuild_func, self.tensor.numpy(), + return self.rebuild_tensor, (_rebuild_func, self.tensor.device, + self.tensor.detach().cpu().numpy(), content[1:]) def _unwrap_tensor(s): diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index a2837e1da13ed..500b1ed84df99 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -570,11 +570,13 @@ def test_buffer_alignment(ray_start_shared_local_modes): def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes): import torch + # test dense tensor tensor = torch.rand(32, 3, 64, 64) ref = ray.put(tensor) tensor_1, tensor_2 = ray.get([ref] * 2) assert tensor_1.data_ptr() == tensor_2.data_ptr() + # test sparse tensor i = torch.arange(0, 1024 * 1024, 4).view(1, -1) v = torch.rand(1024 * 1024 // 4) k = torch.sparse_coo_tensor(i, v, size=(1024 * 1024, )) @@ -583,6 +585,12 @@ def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes): assert tensor_1._indices().data_ptr() == tensor_2._indices().data_ptr() assert tensor_1._values().data_ptr() == tensor_2._values().data_ptr() + # test attributes + tensor = torch.rand(4).requires_grad_(True) + ref = ray.put(tensor) + tensor = ray.get(ref) + assert tensor.requires_grad + if __name__ == "__main__": import pytest From 7ade8c7379cf98b95e75dd1d858f1ca0fb8f9e17 Mon Sep 17 00:00:00 2001 From: Siyuan Date: Wed, 25 Nov 2020 14:53:24 -0800 Subject: [PATCH 6/6] lint --- python/ray/serialization_addons.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/ray/serialization_addons.py b/python/ray/serialization_addons.py index 7ab085a41f85a..3d57a91372965 100644 --- a/python/ray/serialization_addons.py +++ b/python/ray/serialization_addons.py @@ -30,7 +30,7 @@ def rebuild_tensor(cls, rebuild_func, device, ndarray, params): _TORCH_WARNING_FILTER_ACTIVATE = False else: _tensor = torch.from_numpy(ndarray) - if device != torch.device('cpu'): + if device != torch.device("cpu"): _tensor = _tensor.to(device) tensor = rebuild_func(_tensor.storage(), *params) return cls(tensor)