Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] zero-copy serializer for pytorch #12344

Merged
merged 6 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ray/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MessagePackSerializedObject,
RawSerializedObject,
)
from ray import serialization_addons

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -155,6 +156,7 @@ def object_ref_reducer(obj):
# Because objects have default __reduce__ method, we only need to
# treat ObjectRef specifically.
self._register_cloudpickle_reducer(ray.ObjectRef, object_ref_reducer)
serialization_addons.apply(self)

def _register_cloudpickle_reducer(self, cls, reducer):
    # Install a custom reducer for `cls` on the process-wide cloudpickle
    # dispatch table, so every subsequent pickling of `cls` instances in
    # this process goes through `reducer`.
    # NOTE(review): `pickle` here is presumably ray's vendored cloudpickle
    # module (stdlib pickle has no `CloudPickler`) -- confirm at the
    # import site.
    pickle.CloudPickler.dispatch[cls] = reducer
Expand Down
72 changes: 72 additions & 0 deletions python/ray/serialization_addons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
This module is intended for implementing internal serializers for some
site packages.
"""

import warnings

try:
    import torch

    # One-shot flag: installing a warnings filter is expensive on the
    # deserialization hot path, and torch emits the "not writeable"
    # warning only once per process, so we pay the filter cost only for
    # the very first rebuilt tensor.
    _TORCH_WARNING_FILTER_ACTIVATE = True

    class _TorchTensorReducingHelper:
        """Wrapper that lets a torch.Tensor be pickled with its storage
        replaced by a numpy array, reusing the numpy zero-copy
        serialization path while preserving all other tensor metadata.
        """

        def __init__(self, tensor):
            self.tensor = tensor

        @classmethod
        def rebuild_tensor(cls, rebuild_func, device, ndarray, params):
            """Reconstruct a dense tensor from its numpy-backed storage."""
            global _TORCH_WARNING_FILTER_ACTIVATE
            if _TORCH_WARNING_FILTER_ACTIVATE:
                # Suppress torch's one-time UserWarning about the
                # read-only numpy buffer; see note on the flag above.
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore",
                        category=UserWarning,
                        message="The given NumPy array is not writeable")
                    restored = torch.from_numpy(ndarray)
                _TORCH_WARNING_FILTER_ACTIVATE = False
            else:
                restored = torch.from_numpy(ndarray)
            # from_numpy always lands on CPU; move back if needed.
            if device != torch.device("cpu"):
                restored = restored.to(device)
            return cls(rebuild_func(restored.storage(), *params))

        @classmethod
        def rebuild_sparse_tensor(cls, rebuild_func, content):
            """Reconstruct a sparse tensor from its reduced components."""
            return cls(rebuild_func(*content))

        def __reduce_ex__(self, protocol):
            rebuild_func, content = self.tensor.__reduce_ex__(protocol)
            if self.tensor.is_sparse:
                # Torch already reduces a sparse tensor into several
                # contiguous component tensors; forward them unchanged.
                return self.rebuild_sparse_tensor, (rebuild_func, content)
            # Swap only the storage (content[0]) for a numpy array; every
            # other reduction parameter of the tensor is kept as-is.
            ndarray = self.tensor.detach().cpu().numpy()
            return self.rebuild_tensor, (rebuild_func, self.tensor.device,
                                         ndarray, content[1:])

    def _unwrap_tensor(s):
        return s.tensor

    def torch_tensor_reducer(tensor):
        return _unwrap_tensor, (_TorchTensorReducingHelper(tensor), )

except ImportError:
    pass


def apply(serialization_context):
    """Register all add-on serializers on the given serialization context.

    Currently this installs the zero-copy torch.Tensor reducer; it is a
    no-op when torch is not installed.
    """
    try:
        import torch
    except ImportError:
        # torch not available in this environment; nothing to register.
        return
    serialization_context._register_cloudpickle_reducer(
        torch.Tensor, torch_tensor_reducer)
26 changes: 25 additions & 1 deletion python/ray/tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ def __del__(self):
assert new_obj() is None


def test_buffer_alignment():
def test_buffer_alignment(ray_start_shared_local_modes):
# Deserialized large numpy arrays should be 64-byte aligned.
x = np.random.normal(size=(10, 20, 30))
y = ray.get(ray.put(x))
Expand All @@ -568,6 +568,30 @@ def test_buffer_alignment():
assert y.ctypes.data % 8 == 0


def test_pytorch_tensor_zerocopy_serialization(ray_start_shared_local_modes):
    import torch

    # Dense tensor: two gets of the same object ref must resolve to
    # tensors backed by the exact same buffer (zero-copy deserialization).
    dense = torch.rand(32, 3, 64, 64)
    obj_ref = ray.put(dense)
    first, second = ray.get([obj_ref] * 2)
    assert first.data_ptr() == second.data_ptr()

    # Sparse tensor: both the indices and the values components must be
    # shared between the two deserialized copies.
    indices = torch.arange(0, 1024 * 1024, 4).view(1, -1)
    values = torch.rand(1024 * 1024 // 4)
    sparse = torch.sparse_coo_tensor(indices, values, size=(1024 * 1024, ))
    obj_ref = ray.put(sparse)
    first, second = ray.get([obj_ref] * 2)
    assert first._indices().data_ptr() == second._indices().data_ptr()
    assert first._values().data_ptr() == second._values().data_ptr()

    # Tensor attributes such as requires_grad survive the round trip.
    grad_tensor = torch.rand(4).requires_grad_(True)
    assert ray.get(ray.put(grad_tensor)).requires_grad


if __name__ == "__main__":
    # Allow running this test module directly; propagate pytest's
    # exit status to the shell.
    import pytest
    sys.exit(pytest.main(["-v", __file__]))