diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py
index f44c5499cba..af1968c9a1e 100644
--- a/distributed/protocol/pickle.py
+++ b/distributed/protocol/pickle.py
@@ -1,10 +1,14 @@
 from __future__ import print_function, division, absolute_import
 
+import inspect
+from io import BytesIO
 import logging
 import pickle
 
 import cloudpickle
 
+from tornado import gen
+
 from ..utils import ignoring
 
 logger = logging.getLogger(__file__)
@@ -19,6 +23,59 @@
 pickle_types = tuple(pickle_types)
 
 
+@gen.coroutine
+def _tornado_coroutine_sample():
+    yield
+
+def is_tornado_coroutine(func):
+    """
+    Return whether *func* is a Tornado coroutine function.
+    Running coroutines are not supported.
+    """
+    return func.__code__ is _tornado_coroutine_sample.__code__
+
+def _rebuild_tornado_coroutine(func):
+    from tornado import gen
+    return gen.coroutine(func)
+
+def _get_wrapped_function(func):
+    try:
+        return func.__wrapped__
+    except AttributeError:
+        pass
+    # On old Pythons, functools.wraps() doesn't set the __wrapped__
+    # attribute. Hack around it by inspecting captured variables.
+    functions = []
+    for cell in func.__closure__:
+        with ignoring(ValueError):
+            v = cell.cell_contents
+            if inspect.isfunction(v):
+                functions.append(v)
+    if len(functions) != 1:
+        raise RuntimeError("failed to unwrap Tornado coroutine %s: "
+                           "%d candidates found" % (func, len(functions)))
+    return functions[0]
+
+
+class ExtendedPickler(cloudpickle.CloudPickler):
+    """Extended Pickler class with support for Tornado coroutines.
+    """
+
+    def save_function_tuple(self, func):
+        if is_tornado_coroutine(func):
+            self.save_reduce(_rebuild_tornado_coroutine,
+                             (_get_wrapped_function(func),),
+                             obj=func)
+        else:
+            cloudpickle.CloudPickler.save_function_tuple(self, func)
+
+
+def extended_dumps(obj, protocol=2):
+    with BytesIO() as bio:
+        ExtendedPickler(bio, protocol).dump(obj)
+        return bio.getvalue()
+
+
 def dumps(x):
     """ Manage between cloudpickle and pickle
 
@@ -30,17 +87,17 @@ def dumps(x):
         result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
         if len(result) < 1000:
             if b'__main__' in result:
-                return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+                return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
             else:
                 return result
         else:
             if isinstance(x, pickle_types) or b'__main__' not in result:
                 return result
             else:
-                return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
-    except:
+                return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+    except Exception:
         try:
-            return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
+            return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
         except Exception:
             logger.info("Failed to serialize %s", x, exc_info=True)
             raise
diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py
index 789a8b4dedd..e033f3d06f3 100644
--- a/distributed/protocol/tests/test_pickle.py
+++ b/distributed/protocol/tests/test_pickle.py
@@ -5,6 +5,22 @@
 from operator import add
 from functools import partial
 
+from tornado import gen, ioloop
+
+from distributed.utils_test import gen_test
+
+
+@gen.coroutine
+def coro():
+    yield gen.moment
+
+
+class CoroObject(object):
+    @gen.coroutine
+    def f(self, x):
+        yield gen.moment
+        raise gen.Return(x + 1)
+
 
 def test_pickle_data():
     data = [1, b'123', '123', [123], {}, set()]
@@ -28,3 +44,41 @@ def f(x):  # closure
 
     for func in [f, lambda x: x + 1, partial(add, 1)]:
         assert loads(dumps(func))(1) == func(1)
+
+
+def test_global_coroutine():
+    data = dumps(coro)
+    assert loads(data) is coro
+    # Should be tiny
+    assert len(data) < 80
+
+
+@gen_test()
+def test_local_coroutine():
+    @gen.coroutine
+    def f(x, y):
+        yield gen.sleep(x)
+        raise gen.Return(y + 1)
+
+    @gen.coroutine
+    def g(y):
+        res = yield f(0.01, y)
+        raise gen.Return(res + 1)
+
+    data = dumps([g, g])
+    f = g = None
+    g2, g3 = loads(data)
+    assert g2 is g3
+    res = yield g2(5)
+    assert res == 7
+
+
+@gen_test()
+def test_coroutine_method():
+    obj = CoroObject()
+    data = dumps(obj.f)
+    del obj
+    f2 = loads(data)
+    loop = ioloop.IOLoop.current()
+    res = yield f2(5)
+    assert res == 6