diff --git a/distributed/client.py b/distributed/client.py
index 7edff901dc9..b0b3620d503 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -2702,9 +2702,9 @@ def get_dataset(self, name, default=no_default, **kwargs):
 
     async def _run_on_scheduler(self, function, *args, wait=True, **kwargs):
         response = await self.scheduler.run_function(
-            function=dumps(function, protocol=4),
-            args=dumps(args, protocol=4),
-            kwargs=dumps(kwargs, protocol=4),
+            function=dumps(function),
+            args=dumps(args),
+            kwargs=dumps(kwargs),
             wait=wait,
         )
         if response["status"] == "error":
@@ -2765,10 +2765,10 @@ async def _run(
         responses = await self.scheduler.broadcast(
             msg=dict(
                 op="run",
-                function=dumps(function, protocol=4),
-                args=dumps(args, protocol=4),
+                function=dumps(function),
+                args=dumps(args),
                 wait=wait,
-                kwargs=dumps(kwargs, protocol=4),
+                kwargs=dumps(kwargs),
             ),
             workers=workers,
             nanny=nanny,
@@ -4614,7 +4614,7 @@ async def _get_task_stream(
 
     async def _register_scheduler_plugin(self, plugin, name, idempotent=False):
         return await self.scheduler.register_scheduler_plugin(
-            plugin=dumps(plugin, protocol=4),
+            plugin=dumps(plugin),
             name=name,
             idempotent=idempotent,
         )
@@ -4670,7 +4670,7 @@ async def _register_worker_plugin(self, plugin=None, name=None, nanny=None):
         else:
             method = self.scheduler.register_worker_plugin
 
-        responses = await method(plugin=dumps(plugin, protocol=4), name=name)
+        responses = await method(plugin=dumps(plugin), name=name)
         for response in responses.values():
             if response["status"] == "error":
                 _, exc, tb = clean_exception(
diff --git a/distributed/comm/core.py b/distributed/comm/core.py
index 8b97d6c832b..7dff6540d76 100644
--- a/distributed/comm/core.py
+++ b/distributed/comm/core.py
@@ -16,8 +16,8 @@
 from distributed.comm import registry
 from distributed.comm.addressing import parse_address
 from distributed.metrics import time
-from distributed.protocol import pickle
 from distributed.protocol.compression import get_default_compression
+from distributed.protocol.pickle import HIGHEST_PROTOCOL
 
 logger = logging.getLogger(__name__)
 
@@ -133,7 +133,7 @@ def handshake_info():
         return {
             "compression": get_default_compression(),
             "python": tuple(sys.version_info)[:3],
-            "pickle-protocol": pickle.HIGHEST_PROTOCOL,
+            "pickle-protocol": HIGHEST_PROTOCOL,
         }
 
     @staticmethod
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index 5f0c3d858c2..58f4a5e5938 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -47,7 +47,7 @@ def dask_dumps(x, context=None):
     header = {
         "sub-header": sub_header,
         "type": type_name,
-        "type-serialized": pickle.dumps(type(x), protocol=4),
+        "type-serialized": pickle.dumps(type(x)),
         "serializer": "dask",
     }
     return header, frames
@@ -834,7 +834,7 @@ def __init__(self, serializer):
     def serialize(self, est):
         header = {
             "serializer": self.serializer,
-            "type-serialized": pickle.dumps(type(est), protocol=4),
+            "type-serialized": pickle.dumps(type(est)),
             "simple": {},
             "complex": {},
         }
diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py
index 759c9deccdb..c336eaf7ddb 100644
--- a/distributed/protocol/tests/test_numpy.py
+++ b/distributed/protocol/tests/test_numpy.py
@@ -19,7 +19,6 @@
 )
 from distributed.protocol.compression import maybe_compress
 from distributed.protocol.numpy import itemsize
-from distributed.protocol.pickle import HIGHEST_PROTOCOL
 from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
 from distributed.system import MEMORY_LIMIT
 from distributed.utils import nbytes
@@ -87,10 +86,7 @@ def test_dumps_serialize_numpy(x):
     for frame in frames:
         assert isinstance(frame, (bytes, memoryview))
     if x.dtype.char == "O" and any(isinstance(e, np.ndarray) for e in x.flat):
-        if HIGHEST_PROTOCOL >= 5:
-            assert len(frames) > 1
-        else:
-            assert len(frames) == 1
+        assert len(frames) > 1  # pickle protocol >= 5
     y = deserialize(header, frames)
     assert x.shape == y.shape
 
diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py
index 2ab40b5b78b..58dde6c6879 100644
--- a/distributed/protocol/tests/test_pickle.py
+++ b/distributed/protocol/tests/test_pickle.py
@@ -212,9 +212,9 @@ def test_pickle_by_value_when_registered():
            f.write("def myfunc(x):\n    return x + 1")
 
        import mymodule  # noqa
-        assert dumps(
+        assert dumps(mymodule.myfunc) == pickle.dumps(
            mymodule.myfunc, protocol=HIGHEST_PROTOCOL
-        ) == pickle.dumps(mymodule.myfunc, protocol=HIGHEST_PROTOCOL)
+        )
 
        cloudpickle.register_pickle_by_value(mymodule)
        assert len(dumps(mymodule.myfunc)) > len(pickle.dumps(mymodule.myfunc))
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 5342d693e91..e9e463af328 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -4742,7 +4742,6 @@ async def remove_worker(
                             last_worker=ws.clean(),
                             allowed_failures=self.allowed_failures,
                         ),
-                        protocol=4,
                     )
                     r = self.transition(
                         k,
@@ -5802,7 +5801,7 @@ async def send_message(addr):
                 elif on_error == "return":
                     return e
                 elif on_error == "return_pickle":
-                    return dumps(e, protocol=4)
+                    return dumps(e)
                 elif on_error == "ignore":
                     return ERROR
                 else:
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index e6bc03d85fb..8f1bfcc9c32 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -479,7 +479,7 @@ async def test_plugin_internal_exception():
         async with Worker(
             s.address,
             plugins={
-                b"corrupting pickle" + pickle.dumps(lambda: None, protocol=4),
+                b"corrupting pickle" + pickle.dumps(lambda: None),
             },
         ) as w:
             pass
diff --git a/distributed/worker.py b/distributed/worker.py
index 90c87f88368..3a6e76fea46 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -2904,12 +2904,12 @@ def dumps_function(func) -> bytes:
         with _cache_lock:
             result = cache_dumps[func]
     except KeyError:
-        result = pickle.dumps(func, protocol=4)
+        result = pickle.dumps(func)
         if len(result) < 100000:
             with _cache_lock:
                 cache_dumps[func] = result
     except TypeError:  # Unhashable function
-        result = pickle.dumps(func, protocol=4)
+        result = pickle.dumps(func)
     return result
 
 
@@ -2949,7 +2949,7 @@ def dumps_task(task):
 
 def warn_dumps(obj, dumps=pickle.dumps, limit=1e6):
     """Dump an object to bytes, warn if those bytes are large"""
-    b = dumps(obj, protocol=4)
+    b = dumps(obj)
     if not _warn_dumps_warned[0] and len(b) > limit:
         _warn_dumps_warned[0] = True
         s = str(obj)
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 8eb0f6480ec..fc074ac7bea 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -1762,11 +1762,11 @@ def _get_task_finished_msg(
         typ = ts.type = type(value)
         del value
         try:
-            typ_serialized = pickle.dumps(typ, protocol=4)
+            typ_serialized = pickle.dumps(typ)
         except Exception:
             # Some types fail pickling (example: _thread.lock objects),
             # send their name as a best effort.
-            typ_serialized = pickle.dumps(typ.__name__, protocol=4)
+            typ_serialized = pickle.dumps(typ.__name__)
         return TaskFinishedMsg(
             key=ts.key,
             nbytes=ts.nbytes,
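
Context for the change (illustration, not part of the patch): distributed's pickle wrapper already defaults to HIGHEST_PROTOCOL, which is what the updated test_pickle_by_value_when_registered assertion relies on, so dropping the explicit protocol=4 arguments moves serialization to pickle protocol 5 and enables out-of-band buffers. That is also why test_dumps_serialize_numpy now expects more than one frame. The sketch below uses only the standard library plus NumPy as a convenient buffer-backed type; it is a simplified stand-in for distributed's own framing, not code from this patch.

# Sketch of pickle protocol 5 out-of-band buffers. With a buffer_callback,
# the object's data is handed back as PickleBuffer objects instead of being
# copied into the pickle stream.
import pickle

import numpy as np

x = np.arange(1_000_000, dtype="uint8")

buffers = []
payload = pickle.dumps(x, protocol=5, buffer_callback=buffers.append)

print(len(payload))                          # small: metadata only
print(sum(b.raw().nbytes for b in buffers))  # ~1 MB carried out-of-band

# The same buffers must be handed back at load time, roughly mirroring how
# the extra frames are reattached during deserialization.
y = pickle.loads(payload, buffers=buffers)
assert (x == y).all()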