Skip to content

Commit

Permalink
Improve inference for generic classes (PEP 695) (#2433) (#2444)
Browse files Browse the repository at this point in the history
(cherry picked from commit fadac92)

Co-authored-by: Marc Mueller <[email protected]>
  • Loading branch information
github-actions[bot] and cdce8p authored May 19, 2024
1 parent 0ce116d commit 306164e
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 8 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ What's New in astroid 3.2.2?
Release date: TBA


* Improve inference for generic classes using the PEP 695 syntax (Python 3.12).

Closes pylint-dev/#9406


What's New in astroid 3.2.1?
============================
Expand Down
19 changes: 19 additions & 0 deletions astroid/brain/brain_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,20 @@ def infer_typing_attr(
return node.infer(context=ctx)


def _looks_like_generic_class_pep695(node: ClassDef) -> bool:
"""Check if class is using type parameter. Python 3.12+."""
return len(node.type_params) > 0


def infer_typing_generic_class_pep695(
node: ClassDef, ctx: context.InferenceContext | None = None
) -> Iterator[ClassDef]:
"""Add __class_getitem__ for generic classes. Python 3.12+."""
func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
node.locals["__class_getitem__"] = [func_to_add]
return iter([node])


def _looks_like_typedDict( # pylint: disable=invalid-name
node: FunctionDef | ClassDef,
) -> bool:
Expand Down Expand Up @@ -490,3 +504,8 @@ def register(manager: AstroidManager) -> None:

if PY312_PLUS:
register_module_extender(manager, "typing", _typing_transform)
manager.register_transform(
ClassDef,
inference_tip(infer_typing_generic_class_pep695),
_looks_like_generic_class_pep695,
)
5 changes: 4 additions & 1 deletion astroid/nodes/scoped_nodes/scoped_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,10 @@ def scope_lookup(
and name in AstroidManager().builtins_module
)
if (
any(node == base or base.parent_of(node) for base in self.bases)
any(
node == base or base.parent_of(node) and not self.type_params
for base in self.bases
)
or lookup_upper_frame
):
# Handle the case where we have either a name
Expand Down
7 changes: 3 additions & 4 deletions astroid/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,8 +924,7 @@ def generic_type_assigned_stmts(
context: InferenceContext | None = None,
assign_path: None = None,
) -> Generator[nodes.NodeNG, None, None]:
"""Return empty generator (return -> raises StopIteration) so inferred value
is Uninferable.
"""Hack. Return any Node so inference doesn't fail
when evaluating __class_getitem__. Revert if it's causing issues.
"""
return
yield
yield nodes.Const(None)
18 changes: 18 additions & 0 deletions tests/brain/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,24 @@ def test_typing_generic_subscriptable(self):
assert isinstance(inferred, nodes.ClassDef)
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)

@test_utils.require_version(minver="3.12")
def test_typing_generic_subscriptable_pep695(self):
"""Test class using type parameters is subscriptable with __class_getitem__ (added in PY312)"""
node = builder.extract_node(
"""
class Foo[T]: ...
class Bar[T](Foo[T]): ...
"""
)
inferred = next(node.infer())
assert isinstance(inferred, nodes.ClassDef)
assert inferred.name == "Bar"
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
ancestors = list(inferred.ancestors())
assert len(ancestors) == 2
assert ancestors[0].name == "Foo"
assert ancestors[1].name == "object"

@test_utils.require_version(minver="3.9")
def test_typing_annotated_subscriptable(self):
"""Test typing.Annotated is subscriptable with __class_getitem__"""
Expand Down
15 changes: 12 additions & 3 deletions tests/test_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,20 +425,29 @@ def test_assigned_stmts_type_var():
assign_stmts = extract_node("type Point[T] = tuple[float, float]")
type_var: nodes.TypeVar = assign_stmts.type_params[0]
assigned = next(type_var.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None

@staticmethod
def test_assigned_stmts_type_var_tuple():
"""The result is 'Uninferable' and no exception is raised."""
assign_stmts = extract_node("type Alias[*Ts] = tuple[*Ts]")
type_var_tuple: nodes.TypeVarTuple = assign_stmts.type_params[0]
assigned = next(type_var_tuple.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None

@staticmethod
def test_assigned_stmts_param_spec():
"""The result is 'Uninferable' and no exception is raised."""
assign_stmts = extract_node("type Alias[**P] = Callable[P, int]")
param_spec: nodes.ParamSpec = assign_stmts.type_params[0]
assigned = next(param_spec.name.assigned_stmts())
assert assigned is Uninferable
# Hack so inference doesn't fail when evaluating __class_getitem__
# Revert if it's causing issues.
assert isinstance(assigned, nodes.Const)
assert assigned.value is None

0 comments on commit 306164e

Please sign in to comment.