Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import pickle
from pickle import HIGHEST_PROTOCOL

import cloudpickle


HIGHEST_PROTOCOL = pickle.HIGHEST_PROTOCOL

logger = logging.getLogger(__name__)


Expand Down
12 changes: 4 additions & 8 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial
import gc
from operator import add
import pickle
import weakref
import sys

Expand All @@ -9,11 +10,6 @@
from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads

try:
from pickle import PickleBuffer
except ImportError:
pass


def test_pickle_data():
data = [1, b"123", "123", [123], {}, set()]
Expand All @@ -29,7 +25,7 @@ def __init__(self, mv):

def __reduce_ex__(self, protocol):
if protocol >= 5:
return MemoryviewHolder, (PickleBuffer(self.mv),)
return MemoryviewHolder, (pickle.PickleBuffer(self.mv),)
else:
return MemoryviewHolder, (self.mv.tobytes(),)

Expand All @@ -42,7 +38,7 @@ def __reduce_ex__(self, protocol):
mvh2 = loads(d, buffers=l)

assert len(l) == 1
assert isinstance(l[0], PickleBuffer)
assert isinstance(l[0], pickle.PickleBuffer)
assert memoryview(l[0]) == mv
else:
mvh2 = loads(dumps(mvh))
Expand Down Expand Up @@ -84,7 +80,7 @@ def test_pickle_numpy():
l = []
d = dumps(x, buffer_callback=l.append)
assert len(l) == 1
assert isinstance(l[0], PickleBuffer)
assert isinstance(l[0], pickle.PickleBuffer)
assert memoryview(l[0]) == memoryview(x)
assert (loads(d, buffers=l) == x).all()

Expand Down