Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,9 +2702,9 @@ def get_dataset(self, name, default=no_default, **kwargs):

async def _run_on_scheduler(self, function, *args, wait=True, **kwargs):
response = await self.scheduler.run_function(
function=dumps(function, protocol=4),
args=dumps(args, protocol=4),
kwargs=dumps(kwargs, protocol=4),
function=dumps(function),
args=dumps(args),
kwargs=dumps(kwargs),
wait=wait,
)
if response["status"] == "error":
Expand Down Expand Up @@ -2765,10 +2765,10 @@ async def _run(
responses = await self.scheduler.broadcast(
msg=dict(
op="run",
function=dumps(function, protocol=4),
args=dumps(args, protocol=4),
function=dumps(function),
args=dumps(args),
wait=wait,
kwargs=dumps(kwargs, protocol=4),
kwargs=dumps(kwargs),
),
workers=workers,
nanny=nanny,
Expand Down Expand Up @@ -4614,7 +4614,7 @@ async def _get_task_stream(

async def _register_scheduler_plugin(self, plugin, name, idempotent=False):
return await self.scheduler.register_scheduler_plugin(
plugin=dumps(plugin, protocol=4),
plugin=dumps(plugin),
name=name,
idempotent=idempotent,
)
Expand Down Expand Up @@ -4670,7 +4670,7 @@ async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
else:
method = self.scheduler.register_worker_plugin

responses = await method(plugin=dumps(plugin, protocol=4), name=name)
responses = await method(plugin=dumps(plugin), name=name)
for response in responses.values():
if response["status"] == "error":
_, exc, tb = clean_exception(
Expand Down
4 changes: 2 additions & 2 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from distributed.comm import registry
from distributed.comm.addressing import parse_address
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.protocol.compression import get_default_compression
from distributed.protocol.pickle import HIGHEST_PROTOCOL

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,7 +133,7 @@ def handshake_info():
return {
"compression": get_default_compression(),
"python": tuple(sys.version_info)[:3],
"pickle-protocol": pickle.HIGHEST_PROTOCOL,
"pickle-protocol": HIGHEST_PROTOCOL,
}

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def dask_dumps(x, context=None):
header = {
"sub-header": sub_header,
"type": type_name,
"type-serialized": pickle.dumps(type(x), protocol=4),
"type-serialized": pickle.dumps(type(x)),
"serializer": "dask",
}
return header, frames
Expand Down Expand Up @@ -834,7 +834,7 @@ def __init__(self, serializer):
def serialize(self, est):
header = {
"serializer": self.serializer,
"type-serialized": pickle.dumps(type(est), protocol=4),
"type-serialized": pickle.dumps(type(est)),
"simple": {},
"complex": {},
}
Expand Down
6 changes: 1 addition & 5 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from distributed.protocol.compression import maybe_compress
from distributed.protocol.numpy import itemsize
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
from distributed.utils import nbytes
Expand Down Expand Up @@ -87,10 +86,7 @@ def test_dumps_serialize_numpy(x):
for frame in frames:
assert isinstance(frame, (bytes, memoryview))
if x.dtype.char == "O" and any(isinstance(e, np.ndarray) for e in x.flat):
if HIGHEST_PROTOCOL >= 5:
assert len(frames) > 1
else:
assert len(frames) == 1
assert len(frames) > 1 # pickle protocol >= 5
y = deserialize(header, frames)

assert x.shape == y.shape
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,9 @@ def test_pickle_by_value_when_registered():
f.write("def myfunc(x):\n return x + 1")
import mymodule # noqa

assert dumps(
assert dumps(mymodule.myfunc) == pickle.dumps(
mymodule.myfunc, protocol=HIGHEST_PROTOCOL
) == pickle.dumps(mymodule.myfunc, protocol=HIGHEST_PROTOCOL)
)
cloudpickle.register_pickle_by_value(mymodule)
assert len(dumps(mymodule.myfunc)) > len(pickle.dumps(mymodule.myfunc))

Expand Down
3 changes: 1 addition & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4742,7 +4742,6 @@ async def remove_worker(
last_worker=ws.clean(),
allowed_failures=self.allowed_failures,
),
protocol=4,
)
r = self.transition(
k,
Expand Down Expand Up @@ -5802,7 +5801,7 @@ async def send_message(addr):
elif on_error == "return":
return e
elif on_error == "return_pickle":
return dumps(e, protocol=4)
return dumps(e)
elif on_error == "ignore":
return ERROR
else:
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ async def test_plugin_internal_exception():
async with Worker(
s.address,
plugins={
b"corrupting pickle" + pickle.dumps(lambda: None, protocol=4),
b"corrupting pickle" + pickle.dumps(lambda: None),
},
) as w:
pass
Expand Down
6 changes: 3 additions & 3 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2904,12 +2904,12 @@ def dumps_function(func) -> bytes:
with _cache_lock:
result = cache_dumps[func]
except KeyError:
result = pickle.dumps(func, protocol=4)
result = pickle.dumps(func)
if len(result) < 100000:
with _cache_lock:
cache_dumps[func] = result
except TypeError: # Unhashable function
result = pickle.dumps(func, protocol=4)
result = pickle.dumps(func)
return result


Expand Down Expand Up @@ -2949,7 +2949,7 @@ def dumps_task(task):

def warn_dumps(obj, dumps=pickle.dumps, limit=1e6):
"""Dump an object to bytes, warn if those bytes are large"""
b = dumps(obj, protocol=4)
b = dumps(obj)
if not _warn_dumps_warned[0] and len(b) > limit:
_warn_dumps_warned[0] = True
s = str(obj)
Expand Down
4 changes: 2 additions & 2 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,11 +1762,11 @@ def _get_task_finished_msg(
typ = ts.type = type(value)
del value
try:
typ_serialized = pickle.dumps(typ, protocol=4)
typ_serialized = pickle.dumps(typ)
except Exception:
# Some types fail pickling (example: _thread.lock objects),
# send their name as a best effort.
typ_serialized = pickle.dumps(typ.__name__, protocol=4)
typ_serialized = pickle.dumps(typ.__name__)
return TaskFinishedMsg(
key=ts.key,
nbytes=ts.nbytes,
Expand Down