diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 9a1f135444f..9774202e4fe 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -1,5 +1,6 @@ import logging import pickle +from pickle import HIGHEST_PROTOCOL import cloudpickle @@ -23,36 +24,46 @@ def _always_use_pickle_for(x): return False -def dumps(x): +def dumps(x, *, buffer_callback=None): """ Manage between cloudpickle and pickle 1. Try pickle 2. If it is short then check if it contains __main__ 3. If it is long, then first check type, then check __main__ """ + buffers = [] + dump_kwargs = {"protocol": HIGHEST_PROTOCOL} + if HIGHEST_PROTOCOL >= 5 and buffer_callback is not None: + dump_kwargs["buffer_callback"] = buffers.append try: - result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + buffers.clear() + result = pickle.dumps(x, **dump_kwargs) if len(result) < 1000: if b"__main__" in result: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) - else: - return result - else: - if _always_use_pickle_for(x) or b"__main__" not in result: - return result - else: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + buffers.clear() + result = cloudpickle.dumps(x, **dump_kwargs) + elif not _always_use_pickle_for(x) and b"__main__" in result: + buffers.clear() + result = cloudpickle.dumps(x, **dump_kwargs) except Exception: try: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + buffers.clear() + result = cloudpickle.dumps(x, **dump_kwargs) except Exception as e: logger.info("Failed to serialize %s. Exception: %s", x, e) raise + if buffer_callback is not None: + for b in buffers: + buffer_callback(b) + return result -def loads(x): +def loads(x, *, buffers=()): try: - return pickle.loads(x) + if buffers: + return pickle.loads(x, buffers=buffers) + else: + return pickle.loads(x) except Exception: logger.info("Failed to deserialize %s", x[:10000], exc_info=True) raise diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 4d02bc65207..e4fba2b7ba9 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -52,11 +52,16 @@ def dask_loads(header, frames): def pickle_dumps(x): - return {"serializer": "pickle"}, [pickle.dumps(x)] + header = {"serializer": "pickle"} + frames = [None] + buffer_callback = lambda f: frames.append(memoryview(f)) + frames[0] = pickle.dumps(x, buffer_callback=buffer_callback) + return header, frames def pickle_loads(header, frames): - return pickle.loads(b"".join(frames)) + x, buffers = frames[0], frames[1:] + return pickle.loads(x, buffers=buffers) def msgpack_dumps(x): diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 681992ef844..f4a4ec7f8ee 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -6,22 +6,93 @@ import pytest -from distributed.protocol.pickle import dumps, loads +from distributed.protocol import deserialize, serialize +from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads + +try: + from pickle import PickleBuffer +except ImportError: + pass def test_pickle_data(): data = [1, b"123", "123", [123], {}, set()] for d in data: assert loads(dumps(d)) == d + assert deserialize(*serialize(d, serializers=("pickle",))) == d + + +def test_pickle_out_of_band(): + class MemoryviewHolder: + def __init__(self, mv): + self.mv = memoryview(mv) + + def __reduce_ex__(self, protocol): + if protocol >= 5: + return MemoryviewHolder, (PickleBuffer(self.mv),) + else: + return MemoryviewHolder, (self.mv.tobytes(),) + + mv = memoryview(b"123") + mvh = MemoryviewHolder(mv) + + if HIGHEST_PROTOCOL >= 5: + l = [] + d = dumps(mvh, buffer_callback=l.append) + mvh2 = loads(d, buffers=l) + + assert len(l) == 1 + assert isinstance(l[0], PickleBuffer) + assert memoryview(l[0]) == mv + else: + mvh2 = loads(dumps(mvh)) + + assert isinstance(mvh2, MemoryviewHolder) + assert isinstance(mvh2.mv, memoryview) + assert mvh2.mv == mv + + h, f = serialize(mvh, serializers=("pickle",)) + mvh3 = deserialize(h, f) + + assert isinstance(mvh3, MemoryviewHolder) + assert isinstance(mvh3.mv, memoryview) + assert mvh3.mv == mv + + if HIGHEST_PROTOCOL >= 5: + assert len(f) == 2 + assert isinstance(f[0], bytes) + assert isinstance(f[1], memoryview) + assert f[1] == mv + else: + assert len(f) == 1 + assert isinstance(f[0], bytes) def test_pickle_numpy(): np = pytest.importorskip("numpy") x = np.ones(5) assert (loads(dumps(x)) == x).all() + assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() x = np.ones(5000) assert (loads(dumps(x)) == x).all() + assert (deserialize(*serialize(x, serializers=("pickle",))) == x).all() + + if HIGHEST_PROTOCOL >= 5: + x = np.ones(5000) + + l = [] + d = dumps(x, buffer_callback=l.append) + assert len(l) == 1 + assert isinstance(l[0], PickleBuffer) + assert memoryview(l[0]) == memoryview(x) + assert (loads(d, buffers=l) == x).all() + + h, f = serialize(x, serializers=("pickle",)) + assert len(f) == 2 + assert isinstance(f[0], bytes) + assert isinstance(f[1], memoryview) + assert (deserialize(h, f) == x).all() @pytest.mark.xfail( @@ -45,10 +116,17 @@ def funcs(): for func in funcs(): wr = weakref.ref(func) + func2 = loads(dumps(func)) wr2 = weakref.ref(func2) assert func2(1) == func(1) - del func, func2 + + func3 = deserialize(*serialize(func, serializers=("pickle",))) + wr3 = weakref.ref(func3) + assert func3(1) == func(1) + + del func, func2, func3 gc.collect() assert wr() is None assert wr2() is None + assert wr3() is None diff --git a/requirements.txt b/requirements.txt index b0d20cdb1eb..95a681d66c2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ click >= 6.6 -cloudpickle >= 0.2.2 +cloudpickle >= 1.3.0 contextvars;python_version<'3.7' dask >= 2.9.0 msgpack >= 0.6.0