Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def _builtin_type(name):
return getattr(types, name)


import weakref

_SINGLETONS = dict((o, str(o)) for o in [weakref.ref])
_SINGLETONS_REVERSE = dict((v, k) for k, v in _SINGLETONS.items())


def _singleton(name):
return _SINGLETONS_REVERSE[name]


class CloudPickler(Pickler):

dispatch = Pickler.dispatch.copy()
Expand All @@ -109,6 +119,8 @@ def dump(self, obj):
if 'recursion' in e.args[0]:
msg = """Could not pickle object as excessively deep recursion required."""
raise pickle.PicklingError(msg)
else:
raise

def save_memoryview(self, obj):
"""Fallback to save_string"""
Expand Down Expand Up @@ -199,7 +211,7 @@ def save_function(self, obj, name=None):
# a builtin_function_or_method which comes in as an attribute of some
# object (e.g., object.__new__, itertools.chain.from_iterable) will end
# up with modname "__main__" and so end up here. But these functions
# have no __code__ attribute in CPython, so the handling for
# have no __code__ attribute in CPython, so the handling for
# user-defined functions below will fail.
# So we pickle them here using save_reduce; have to do it differently
# for different python versions.
Expand Down Expand Up @@ -354,6 +366,8 @@ def save_global(self, obj, name=None, pack=struct.pack):
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
if obj in _SINGLETONS:
return self.save_reduce(_singleton, (_SINGLETONS[obj],), obj=obj)

if name is None:
name = obj.__name__
Expand Down
4 changes: 4 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ def test_find_module(self):
def test_Ellipsis(self):
self.assertEqual(Ellipsis, pickle_depickle(Ellipsis))

def test_weakref(self):
import weakref
self.assertEqual(weakref.ref, pickle_depickle(weakref.ref))

def test_NotImplemented(self):
self.assertEqual(NotImplemented, pickle_depickle(NotImplemented))

Expand Down