Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions distributed/protocol/pickle.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import print_function, division, absolute_import

import inspect
from io import BytesIO
import logging
import pickle

import cloudpickle

from tornado import gen

from ..utils import ignoring

# Name the logger after the module rather than its file path so that it takes
# part in the dotted logger hierarchy and can be configured via its package.
logger = logging.getLogger(__name__)
Expand All @@ -19,6 +23,59 @@
pickle_types = tuple(pickle_types)


# Throwaway coroutine: is_tornado_coroutine() compares the ``__code__`` of
# other functions against the ``__code__`` of this decorated function.
@gen.coroutine
def _tornado_coroutine_sample():
    yield

def is_tornado_coroutine(func):
    """
    Return whether *func* is a Tornado coroutine function.
    Running coroutines are not supported.

    The check is done by identity: *func* is considered a coroutine
    function when its ``__code__`` is the very same object as the
    ``__code__`` of ``_tornado_coroutine_sample``.  Objects that have no
    ``__code__`` attribute at all (builtins, ``functools.partial``
    objects, ...) are reported as not being coroutines instead of raising
    ``AttributeError``.
    """
    code = getattr(func, '__code__', None)
    return code is _tornado_coroutine_sample.__code__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How reliable is this check? What sort of change in @gen.coroutine would break this check?

Copy link
Member Author

@pitrou pitrou Nov 15, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would need gen.coroutine to return something other than a wrapper function. Apparently returning a wrapper function has been the case since the beginning, see tornadoweb/tornado@58b0dab


def _rebuild_tornado_coroutine(func):
    """Wrap the plain function *func* with ``tornado.gen.coroutine``.

    ExtendedPickler passes this function to ``save_reduce`` as the callable
    that rebuilds a coroutine function from its wrapped function.
    """
    from tornado.gen import coroutine
    return coroutine(func)

def _get_wrapped_function(func):
    """Return the function wrapped by the decorator wrapper *func*.

    The ``__wrapped__`` attribute is used when present.  Otherwise the
    variables captured in the closure of *func* are inspected and the
    single plain function found among them is returned.

    Raises
    ------
    RuntimeError
        If the closure does not hold exactly one function (including the
        case where *func* has no closure at all).
    """
    try:
        return func.__wrapped__
    except AttributeError:
        pass
    # On old Pythons, functools.wraps() doesn't set the __wrapped__
    # attribute.  Hack around it by inspecting captured variables.
    candidates = []
    # ``__closure__`` is None when nothing is captured; treat that as an
    # empty closure so the RuntimeError below is raised, not a TypeError.
    for cell in getattr(func, '__closure__', None) or ():
        try:
            value = cell.cell_contents
        except ValueError:
            # Reading an empty cell raises ValueError; nothing to inspect.
            continue
        if inspect.isfunction(value):
            candidates.append(value)
    if len(candidates) != 1:
        raise RuntimeError("failed to unwrap Tornado coroutine %s: "
                           "%d candidates found" % (func, len(candidates)))
    return candidates[0]


class ExtendedPickler(cloudpickle.CloudPickler):
    """CloudPickler subclass that additionally handles Tornado coroutines.
    """

    def save_function_tuple(self, func):
        """Save *func*, special-casing Tornado coroutine functions.

        A coroutine function is reduced to a call of
        ``_rebuild_tornado_coroutine`` on its wrapped function; any other
        function is delegated to ``CloudPickler.save_function_tuple``.
        """
        if not is_tornado_coroutine(func):
            cloudpickle.CloudPickler.save_function_tuple(self, func)
            return
        wrapped = _get_wrapped_function(func)
        self.save_reduce(_rebuild_tornado_coroutine, (wrapped,), obj=func)


def extended_dumps(obj, protocol=2):
    """Serialize *obj* with ExtendedPickler and return the resulting bytes.

    *protocol* is passed through to the pickler.
    """
    buf = BytesIO()
    try:
        pickler = ExtendedPickler(buf, protocol)
        pickler.dump(obj)
        return buf.getvalue()
    finally:
        buf.close()


def dumps(x):
""" Manage between cloudpickle and pickle

Expand All @@ -30,17 +87,17 @@ def dumps(x):
result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
if len(result) < 1000:
if b'__main__' in result:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
else:
return result
else:
if isinstance(x, pickle_types) or b'__main__' not in result:
return result
else:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
except:
return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
except Exception:
try:
return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
except Exception:
logger.info("Failed to serialize %s", x, exc_info=True)
raise
Expand Down
54 changes: 54 additions & 0 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@
from operator import add
from functools import partial

from tornado import gen, ioloop

from distributed.utils_test import gen_test


# Module-level coroutine used by test_global_coroutine below.
@gen.coroutine
def coro():
    yield gen.moment


class CoroObject(object):
    """Object exposing a coroutine method, used by test_coroutine_method."""

    @gen.coroutine
    def f(self, x):
        """Yield once to the loop, then produce ``x + 1``."""
        yield gen.moment
        incremented = x + 1
        raise gen.Return(incremented)


def test_pickle_data():
data = [1, b'123', '123', [123], {}, set()]
Expand All @@ -28,3 +44,41 @@ def f(x): # closure

for func in [f, lambda x: x + 1, partial(add, 1)]:
assert loads(dumps(func))(1) == func(1)


def test_global_coroutine():
    """A module-level coroutine round-trips to the very same object."""
    payload = dumps(coro)
    restored = loads(payload)
    assert restored is coro
    # Should be tiny
    size = len(payload)
    assert size < 80


@gen_test()
def test_local_coroutine():
    """Locally defined coroutines survive a dumps/loads round trip."""
    @gen.coroutine
    def f(x, y):
        yield gen.sleep(x)
        raise gen.Return(y + 1)

    @gen.coroutine
    def g(y):
        inner_result = yield f(0.01, y)
        raise gen.Return(inner_result + 1)

    payload = dumps([g, g])
    # Drop the local references before loading the payload back.
    f = g = None
    first, second = loads(payload)
    # Both list entries were the same object when dumped.
    assert first is second
    outcome = yield first(5)
    assert outcome == 7


@gen_test()
def test_coroutine_method():
    """A bound coroutine method survives a dumps/loads round trip."""
    obj = CoroObject()
    data = dumps(obj.f)
    # Drop the original instance before loading the pickled method back.
    del obj
    f2 = loads(data)
    # The unused ``loop = ioloop.IOLoop.current()`` local was removed: its
    # value was never read by the test.
    res = yield f2(5)
    assert res == 6