Skip to content

Commit

Permalink
feat: type_unparametrized (#133)
Browse files Browse the repository at this point in the history
* feat: type_unparametrized

For getting the un-parametrized type of an object. This is useful for doing `type(obj)(*args, **kwargs)`

Signed-off-by: nstarman <[email protected]>

* fix py3.8 compat

Signed-off-by: nstarman <[email protected]>

* fix docstring

Signed-off-by: nstarman <[email protected]>

* run pre-commit

Signed-off-by: nstarman <[email protected]>

---------

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Apr 20, 2024
1 parent c372312 commit 300487f
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 3 deletions.
116 changes: 114 additions & 2 deletions plum/parametric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Type, TypeVar, Union

import beartype.door
from beartype.roar import BeartypeDoorNonpepException
Expand All @@ -13,11 +13,14 @@
"CovariantMeta",
"parametric",
"type_parameter",
"type_unparametrized",
"kind",
"Kind",
"Val",
]

T = TypeVar("T")


_dispatch = Dispatcher()

Expand Down Expand Up @@ -274,11 +277,82 @@ def class_new(cls, *args, **kw_args):
cls.__new__ = class_new
super(original_class, cls).__init_subclass__(**kw_args)

def __class_nonparametric__(cls):
"""Return the non-parametric type of an object.
:mod:`plum.parametric` produces parametric subtypes of classes. This
method can be used to get the non-parametric type of an object.
Examples
--------
>>> from plum import parametric
>>> @parametric
... class Obj:
... @classmethod
... def __infer_type_parameter__(cls, *arg):
... return type(arg[0])
... def __init__(self, x):
... self.x = x
... def __repr__(self):
... return f"Obj({self.x})"
>>> obj = Obj(1)
>>> obj
Obj(1)
>>> type(obj).mro()
[Obj[int], Obj, object]
>>> obj.__class_nonparametric__().mro()
[Obj, object]
"""
return original_class

def __class_unparametrized__(cls):
"""Return the unparametrized type of an object.
:mod:`plum.parametric` produces parametric subtypes of classes. This
method can be used to get the un-parametrized type of an object.
Examples
--------
>>> from plum import parametric
>>> @parametric
... class Obj:
... @classmethod
... def __infer_type_parameter__(cls, *arg):
... return type(arg[0])
... def __init__(self, x):
... self.x = x
... def __repr__(self):
... return f"Obj({self.x})"
>>> obj = Obj(1)
>>> obj
Obj(1)
>>> type(obj).__name__
Obj[int]
>>> obj.__class_unparametrized__().mro()
[Obj, Obj, object]
Note that this is still NOT the 'original' non-`parametric`-wrapped
type. This is the type that is wrapped by :mod:`plum.parametric`, but
without the inferred type parameter(s).
"""
return parametric_class

# Create parametric class.
parametric_class = meta(
original_class.__name__,
(original_class,),
{"__new__": __new__, "__init_subclass__": __init_subclass__},
{
"__new__": __new__,
"__init_subclass__": __init_subclass__,
"__class_nonparametric__": __class_nonparametric__,
"__class_unparametrized__": __class_unparametrized__,
},
)
parametric_class._parametric = True
parametric_class._concrete = False
Expand Down Expand Up @@ -356,6 +430,44 @@ def type_parameter(x):
)


def type_unparametrized(q: T) -> Type[T]:
"""Return the unparametrized type of an object.
:mod:`plum.parametric` produces parametric subtypes of classes. This
function can be used to get the un-parametrized type of an object.
This function also works for normal, :mod:`plum.parametric`-wrapped classes.
Examples
--------
>>> from plum import parametric
>>> @parametric
... class Obj:
... @classmethod
... def __infer_type_parameter__(cls, *arg):
... return type(arg[0])
... def __init__(self, x):
... self.x = x
... def __repr__(self):
... return f"Obj({self.x})"
>>> obj = Obj(1)
>>> obj
Obj(1)
>>> type(obj).__name__
Obj[int]
>>> type_unparametrized(obj).__name__
Obj
Note that this is still NOT the 'original' non-`parametric`-wrapped type.
This is the type that is wrapped by :mod:`plum.parametric`, but without the
inferred type parameter(s).
"""
typ = type(q)
return q.__class_unparametrized__() if isinstance(typ, ParametricTypeMeta) else typ


def kind(SuperClass=object):
"""Create a parametric wrapper type for dispatch purposes.
Expand Down
33 changes: 32 additions & 1 deletion tests/test_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
parametric,
type_parameter,
)
from plum.parametric import CovariantMeta, is_concrete, is_type
from plum.parametric import CovariantMeta, is_concrete, is_type, type_unparametrized


def test_covariantmeta():
Expand Down Expand Up @@ -60,6 +60,15 @@ class A(Base1, metaclass=metaclass):
assert issubclass(type(a1), Base1)
assert not issubclass(type(a1), Base2)

assert a1.__class_unparametrized__() is A
assert a2.__class_unparametrized__() is A

# Here we are testing that the class returned by `__class_nonparametric__`
# is the 'original' class that the @parametric decorator was applied to.
assert a1.__class_nonparametric__() is A.mro()[1]
assert issubclass(a2.__class_nonparametric__(), Base1)
assert a2.__class_nonparametric__() is not Base1

# Test multiple type parameters.
assert A[1, 2] == A[1, 2]

Expand Down Expand Up @@ -575,3 +584,25 @@ class Wrapper(Pytree):

Wrapper[int]
assert Wrapper[int] in register


def test_type_unparametrized():
"""Test the `type_unparametrized` function."""

@parametric
class Obj:
@classmethod
def __infer_type_parameter__(cls, *arg):
return type(arg[0])

def __init__(self, x):
self.x = x

def __repr__(self):
return f"Obj({self.x})"

pobj = Obj(1)

assert type(pobj) is Obj[int]
assert type_unparametrized(pobj) is not Obj[int]
assert type_unparametrized(pobj) is Obj

0 comments on commit 300487f

Please sign in to comment.