diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index 9774202e4fe..6e168947d48 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -1,10 +1,11 @@ import logging import pickle -from pickle import HIGHEST_PROTOCOL import cloudpickle +HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL + logger = logging.getLogger(__name__) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index f4a4ec7f8ee..bd784117186 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -1,6 +1,7 @@ from functools import partial import gc from operator import add +import pickle import weakref import sys @@ -9,11 +10,6 @@ from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads -try: - from pickle import PickleBuffer -except ImportError: - pass - def test_pickle_data(): data = [1, b"123", "123", [123], {}, set()] @@ -29,7 +25,7 @@ def __init__(self, mv): def __reduce_ex__(self, protocol): if protocol >= 5: - return MemoryviewHolder, (PickleBuffer(self.mv),) + return MemoryviewHolder, (pickle.PickleBuffer(self.mv),) else: return MemoryviewHolder, (self.mv.tobytes(),) @@ -42,7 +38,7 @@ def __reduce_ex__(self, protocol): mvh2 = loads(d, buffers=l) assert len(l) == 1 - assert isinstance(l[0], PickleBuffer) + assert isinstance(l[0], pickle.PickleBuffer) assert memoryview(l[0]) == mv else: mvh2 = loads(dumps(mvh)) @@ -84,7 +80,7 @@ def test_pickle_numpy(): l = [] d = dumps(x, buffer_callback=l.append) assert len(l) == 1 - assert isinstance(l[0], PickleBuffer) + assert isinstance(l[0], pickle.PickleBuffer) assert memoryview(l[0]) == memoryview(x) assert (loads(d, buffers=l) == x).all()