From fb46550849976d5dc354b9b7cca3c45cdb0f7224 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 15 Nov 2016 12:14:39 +0100 Subject: [PATCH 1/2] Fix #607: make Tornado coroutines serializable As a bonus, serializing coroutines is now much faster and produces a smaller payload. --- distributed/protocol/pickle.py | 65 +++++++++++++++++++++-- distributed/protocol/tests/test_pickle.py | 51 ++++++++++++++++++ 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/pickle.py b/distributed/protocol/pickle.py index f44c5499cba..af1968c9a1e 100644 --- a/distributed/protocol/pickle.py +++ b/distributed/protocol/pickle.py @@ -1,10 +1,14 @@ from __future__ import print_function, division, absolute_import +import inspect +from io import BytesIO import logging import pickle import cloudpickle +from tornado import gen + from ..utils import ignoring logger = logging.getLogger(__file__) @@ -19,6 +23,59 @@ pickle_types = tuple(pickle_types) +@gen.coroutine +def _tornado_coroutine_sample(): + yield + +def is_tornado_coroutine(func): + """ + Return whether *func* is a Tornado coroutine function. + Running coroutines are not supported. + """ + return func.__code__ is _tornado_coroutine_sample.__code__ + +def _rebuild_tornado_coroutine(func): + from tornado import gen + return gen.coroutine(func) + +def _get_wrapped_function(func): + try: + return func.__wrapped__ + except AttributeError: + pass + # On old Pythons, functools.wraps() doesn't set the __wrapped__ + # attribute. Hack around it by inspecting captured variables. + functions = [] + for cell in func.__closure__: + with ignoring(ValueError): + v = cell.cell_contents + if inspect.isfunction(v): + functions.append(v) + if len(functions) != 1: + raise RuntimeError("failed to unwrap Tornado coroutine %s: " + "%d candidates found" % (func, len(functions))) + return functions[0] + + +class ExtendedPickler(cloudpickle.CloudPickler): + """Extended Pickler class with support for Tornado coroutines. + """ + + def save_function_tuple(self, func): + if is_tornado_coroutine(func): + self.save_reduce(_rebuild_tornado_coroutine, + (_get_wrapped_function(func),), + obj=func) + else: + cloudpickle.CloudPickler.save_function_tuple(self, func) + + +def extended_dumps(obj, protocol=2): + with BytesIO() as bio: + ExtendedPickler(bio, protocol).dump(obj) + return bio.getvalue() + + def dumps(x): """ Manage between cloudpickle and pickle @@ -30,17 +87,17 @@ def dumps(x): result = pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) if len(result) < 1000: if b'__main__' in result: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL) else: return result else: if isinstance(x, pickle_types) or b'__main__' not in result: return result else: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) - except: + return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + except Exception: try: - return cloudpickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL) + return extended_dumps(x, protocol=pickle.HIGHEST_PROTOCOL) except Exception: logger.info("Failed to serialize %s", x, exc_info=True) raise diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 789a8b4dedd..3f8e467d9ea 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -5,6 +5,20 @@ from operator import add from functools import partial +from tornado import gen, ioloop + + +@gen.coroutine +def coro(): + yield gen.moment + + +class CoroObject(object): + @gen.coroutine + def f(self, x): + yield gen.moment + raise gen.Return(x + 1) + def test_pickle_data(): data = [1, b'123', '123', [123], {}, set()] @@ -28,3 +42,40 @@ def f(x): # closure for func in [f, lambda x: x + 1, partial(add, 1)]: assert loads(dumps(func))(1) == func(1) + + +def test_global_coroutine(): + data = dumps(coro) + assert loads(data) is coro + # Should be tiny + assert len(data) < 80 + + +def test_local_coroutine(): + @gen.coroutine + def f(x, y): + yield gen.sleep(x) + raise gen.Return(y + 1) + + @gen.coroutine + def g(y): + res = yield f(0.01, y) + raise gen.Return(res + 1) + + data = dumps([g, g]) + f = g = None + g2, g3 = loads(data) + assert g2 is g3 + loop = ioloop.IOLoop.current() + res = loop.run_sync(partial(g2, 5)) + assert res == 7 + + +def test_coroutine_method(): + obj = CoroObject() + data = dumps(obj.f) + del obj + f2 = loads(data) + loop = ioloop.IOLoop.current() + res = loop.run_sync(partial(f2, 5)) + assert res == 6 From 7befd5ccddb6f0a99dba8277f016c2cd623980fa Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 15 Nov 2016 12:48:34 +0100 Subject: [PATCH 2/2] Fix tests by using @gen_test() --- distributed/protocol/tests/test_pickle.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 3f8e467d9ea..e033f3d06f3 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -7,6 +7,8 @@ from tornado import gen, ioloop +from distributed.utils_test import gen_test + @gen.coroutine def coro(): @@ -51,6 +53,7 @@ def test_global_coroutine(): assert len(data) < 80 +@gen_test() def test_local_coroutine(): @gen.coroutine def f(x, y): @@ -66,16 +69,16 @@ def g(y): f = g = None g2, g3 = loads(data) assert g2 is g3 - loop = ioloop.IOLoop.current() - res = loop.run_sync(partial(g2, 5)) + res = yield g2(5) assert res == 7 +@gen_test() def test_coroutine_method(): obj = CoroObject() data = dumps(obj.f) del obj f2 = loads(data) loop = ioloop.IOLoop.current() - res = loop.run_sync(partial(f2, 5)) + res = yield f2(5) assert res == 6