diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 2d704fc1f..8fb5a8310 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -90,6 +90,16 @@ def _builtin_type(name): return getattr(types, name) +import weakref + +_SINGLETONS = dict((o, str(o)) for o in [weakref.ref]) +_SINGLETONS_REVERSE = dict((v, k) for k, v in _SINGLETONS.items()) + + +def _singleton(name): + return _SINGLETONS_REVERSE[name] + + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -109,6 +119,8 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) + else: + raise def save_memoryview(self, obj): """Fallback to save_string""" @@ -199,7 +211,7 @@ def save_function(self, obj, name=None): # a builtin_function_or_method which comes in as an attribute of some # object (e.g., object.__new__, itertools.chain.from_iterable) will end # up with modname "__main__" and so end up here. But these functions - # have no __code__ attribute in CPython, so the handling for + # have no __code__ attribute in CPython, so the handling for # user-defined functions below will fail. # So we pickle them here using save_reduce; have to do it differently # for different python versions. @@ -354,6 +366,8 @@ def save_global(self, obj, name=None, pack=struct.pack): if obj.__module__ == "__builtin__" or obj.__module__ == "builtins": if obj in _BUILTIN_TYPE_NAMES: return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj) + if obj in _SINGLETONS: + return self.save_reduce(_singleton, (_SINGLETONS[obj],), obj=obj) if name is None: name = obj.__name__ diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index ec19e1fa4..37057826b 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -289,6 +289,10 @@ def test_find_module(self): def test_Ellipsis(self): self.assertEqual(Ellipsis, pickle_depickle(Ellipsis)) + def test_weakref(self): + import weakref + self.assertEqual(weakref.ref, pickle_depickle(weakref.ref)) + def test_NotImplemented(self): self.assertEqual(NotImplemented, pickle_depickle(NotImplemented))