diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py index 2140c2f0c4e..caeca147169 100644 --- a/distributed/protocol/numpy.py +++ b/distributed/protocol/numpy.py @@ -4,7 +4,7 @@ from .serialize import dask_serialize, dask_deserialize from . import pickle -from ..utils import log_errors +from ..utils import log_errors, nbytes def itemsize(dt): @@ -22,7 +22,10 @@ def itemsize(dt): def serialize_numpy_ndarray(x): if x.dtype.hasobject: header = {"pickle": True} - frames = [pickle.dumps(x)] + frames = [None] + buffer_callback = lambda f: frames.append(memoryview(f)) + frames[0] = pickle.dumps(x, buffer_callback=buffer_callback) + header["lengths"] = tuple(map(nbytes, frames)) return header, frames # We cannot blindly pickle the dtype as some may fail pickling, @@ -96,10 +99,10 @@ def serialize_numpy_ndarray(x): @dask_deserialize.register(np.ndarray) def deserialize_numpy_ndarray(header, frames): with log_errors(): - (frame,) = frames - if header.get("pickle"): - return pickle.loads(frame) + return pickle.loads(frames[0], buffers=frames[1:]) + + (frame,) = frames is_custom, dt = header["dtype"] if is_custom: diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 830991fd56a..0e299632902 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -14,6 +14,7 @@ ) from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.protocol.numpy import itemsize +from distributed.protocol.pickle import HIGHEST_PROTOCOL from distributed.protocol.compression import maybe_compress from distributed.system import MEMORY_LIMIT from distributed.utils import tmpfile, nbytes @@ -57,6 +58,7 @@ def test_serialize(): np.array(["abc"], dtype="S3"), np.array(["abc"], dtype="U3"), np.array(["abc"], dtype=object), + np.array([np.arange(3), np.arange(4, 6)], dtype=object), np.ones(shape=(5,), dtype=("f8", 32)), np.ones(shape=(5,), dtype=[("x", "f8", 32)]), np.ones(shape=(5,), dtype=np.dtype([("a", "i1"), ("b", "f8")], align=False)), @@ -79,12 +81,24 @@ def test_dumps_serialize_numpy(x): frames = decompress(header, frames) for frame in frames: assert isinstance(frame, (bytes, memoryview)) + if x.dtype.char == "O" and any(isinstance(e, np.ndarray) for e in x.flat): + if HIGHEST_PROTOCOL >= 5: + assert len(frames) > 1 + else: + assert len(frames) == 1 y = deserialize(header, frames) - np.testing.assert_equal(x, y) + assert x.shape == y.shape + assert x.dtype == y.dtype if x.flags.c_contiguous or x.flags.f_contiguous: assert x.strides == y.strides + if x.dtype.char == "O": + for e_x, e_y in zip(x.flat, y.flat): + np.testing.assert_equal(e_x, e_y) + else: + np.testing.assert_equal(x, y) + @pytest.mark.parametrize( "x", diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index bd784117186..9ee496f5e9f 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -74,6 +74,25 @@ def test_pickle_numpy(): assert (loads(dumps(x)) == x).all() assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() + x = np.array([np.arange(3), np.arange(4, 6)], dtype=object) + x2 = loads(dumps(x)) + assert x.shape == x2.shape + assert x.dtype == x2.dtype + assert x.strides == x2.strides + for e_x, e_x2 in zip(x.flat, x2.flat): + np.testing.assert_equal(e_x, e_x2) + h, f = serialize(x, serializers=("pickle",)) + if HIGHEST_PROTOCOL >= 5: + assert len(f) == 3 + else: + assert len(f) == 1 + x3 = deserialize(h, f) + assert x.shape == x3.shape + assert x.dtype == x3.dtype + assert x.strides == x3.strides + for e_x, e_x3 in zip(x.flat, x3.flat): + np.testing.assert_equal(e_x, e_x3) + if HIGHEST_PROTOCOL >= 5: x = np.ones(5000)