Skip to content
20 changes: 18 additions & 2 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,8 +644,9 @@ def save_dynamic_class(self, obj):
else:
# "Regular" class definition:
tp = type(obj)
bases = _get_bases(obj)
self.save_reduce(_make_skeleton_class,
(tp, obj.__name__, obj.__bases__, type_kwargs,
(tp, obj.__name__, bases, type_kwargs,
_ensure_tracking(obj), None),
obj=obj)

Expand Down Expand Up @@ -1163,10 +1164,17 @@ class id will also reuse this class definition.
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility shall the need arise.
"""
skeleton_class = type_constructor(name, bases, type_kwargs)
skeleton_class = _make_new_class(type_constructor, name, bases, type_kwargs)
return _lookup_class_or_track(class_tracker_id, skeleton_class)


def _make_new_class(type_constructor, name, bases, type_kwargs):
return types.new_class(
name, bases, {'metaclass': type_constructor},
lambda ns: ns.update(type_kwargs)
)


def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.

Expand Down Expand Up @@ -1268,3 +1276,11 @@ def _typevar_reduce(obj):
if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
return (getattr, module_and_name)


def _get_bases(typ):
if hasattr(typ, '__orig_bases__'):
bases_attr = '__orig_bases__'
else:
bases_attr = '__bases__'
return getattr(typ, bases_attr)
4 changes: 2 additions & 2 deletions cloudpickle/cloudpickle_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type,
Enum, _ensure_tracking, _make_skeleton_class, _make_skeleton_enum,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce,
_extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases,
)

load, loads = _pickle.load, _pickle.loads
Expand Down Expand Up @@ -76,7 +76,7 @@ def _class_getnewargs(obj):
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__

return (type(obj), obj.__name__, obj.__bases__, type_kwargs,
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
_ensure_tracking(obj), None)


Expand Down
47 changes: 47 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2130,6 +2130,53 @@ def test_pickle_importable_typevar(self):
from typing import AnyStr
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic(self):
from typing import (
Optional, TypeVar, Generic, Tuple, Callable,
Dict, Any, ClassVar, NoReturn, Union, List,
)

T = TypeVar('T')

class C(Generic[T]):
pass

objs = [
C, C[int],
T, Any, NoReturn, Optional, Generic,
Union, ClassVar,
Optional[int],
Generic[T],
Callable[[int], Any],
Callable[..., Any],
Callable[[], Any],
Tuple[int, ...],
Tuple[int, C[int]],
ClassVar[C[int]],
List[int],
Dict[int, str],
]

for obj in objs:
_ = pickle_depickle(obj, protocol=self.protocol)

@unittest.skipIf(sys.version_info < (3, 7),
"Pickling generics not supported below py37")
def test_generic_extensions(self):
typing_extensions = pytest.importorskip('typing_extensions')

objs = [
typing_extensions.Literal,
typing_extensions.Final,
typing_extensions.Literal['a'],
typing_extensions.Final[int],
]

for obj in objs:
_ = pickle_depickle(obj, protocol=self.protocol)


class Protocol2CloudPickleTest(CloudPickleTest):

Expand Down