From 2f287e67f18494ba21795abed2afaaa94ab3fe9e Mon Sep 17 00:00:00 2001 From: anthony sottile <103459774+asottile-sentry@users.noreply.github.com> Date: Mon, 22 Jul 2024 14:16:32 -0400 Subject: [PATCH] use field annotations for values_list types (#2248) (#20) Co-authored-by: Anthony Sottile Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- mypy_django_plugin/django/context.py | 65 +++++++++++++------ mypy_django_plugin/transformers/querysets.py | 13 ++-- .../managers/querysets/test_values_list.yml | 20 ++++++ 3 files changed, 75 insertions(+), 23 deletions(-) diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 4da4cb29d..3a6af7042 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -91,6 +91,35 @@ class LookupsAreUnsupported(Exception): pass +def _get_field_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[Instance]: + if info is None: + return None + field_node = info.get(field_name) + if field_node is None or not isinstance(field_node.type, Instance): + return None + # Field declares a set and a get type arg. Fallback to `None` when we can't find any args + elif len(field_node.type.args) != 2: + return None + else: + return field_node.type + + +def _get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]: + field_type = _get_field_type_from_model_type_info(info, field_name) + if field_type is not None: + return field_type.args[0] + else: + return None + + +def _get_field_get_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]: + field_type = _get_field_type_from_model_type_info(info, field_name) + if field_type is not None: + return field_type.args[1] + else: + return None + + class DjangoContext: def __init__(self, django_settings_module: str) -> None: self.django_settings_module = django_settings_module @@ -152,13 +181,13 @@ def get_field_lookup_exact_type( ) -> MypyType: if isinstance(field, (RelatedField, ForeignObjectRel)): related_model_cls = self.get_field_related_model_cls(field) - primary_key_field = self.get_primary_key_field(related_model_cls) - primary_key_type = self.get_field_get_type(api, primary_key_field, method="init") - rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if rel_model_info is None: return AnyType(TypeOfAny.explicit) + primary_key_field = self.get_primary_key_field(related_model_cls) + primary_key_type = self.get_field_get_type(api, rel_model_info, primary_key_field, method="init") + model_and_primary_key_type = UnionType.make_union([Instance(rel_model_info, []), primary_key_type]) return helpers.make_optional(model_and_primary_key_type) @@ -200,19 +229,6 @@ def get_expected_types(self, api: TypeChecker, model_cls: Type[Model], *, method field_set_type = self.get_field_set_type(api, primary_key_field, method=method) expected_types["pk"] = field_set_type - def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name: str) -> Optional[MypyType]: - if info is None: - return None - field_node = info.get(field_name) - if field_node is None or not isinstance(field_node.type, Instance): - return None - elif not field_node.type.args: - # Field declares a set and a get type arg. Fallback to `None` when we can't find any args - return None - - set_type = field_node.type.args[0] - return set_type - model_info = helpers.lookup_class_typeinfo(api, model_cls) for field in model_cls._meta.get_fields(): if isinstance(field, Field): @@ -223,7 +239,7 @@ def get_field_set_type_from_model_type_info(info: Optional[TypeInfo], field_name # Try to retrieve set type from a model's TypeInfo object and fallback to retrieving it manually # from django-stubs own declaration. This is to align with the setter types declared for # assignment. - field_set_type = get_field_set_type_from_model_type_info( + field_set_type = _get_field_set_type_from_model_type_info( model_info, field_name ) or self.get_field_set_type(api, field, method=method) expected_types[field_name] = field_set_type @@ -340,9 +356,19 @@ def get_field_set_type( return field_set_type def get_field_get_type( - self, api: TypeChecker, field: Union["Field[Any, Any]", ForeignObjectRel], *, method: str + self, + api: TypeChecker, + model_info: Optional[TypeInfo], + field: Union["Field[Any, Any]", ForeignObjectRel], + *, + method: str, ) -> MypyType: """Get a type of __get__ for this specific Django field.""" + if isinstance(field, Field): + get_type = _get_field_get_type_from_model_type_info(model_info, field.attname) + if get_type is not None: + return get_type + field_info = helpers.lookup_class_typeinfo(api, field.__class__) if field_info is None: return AnyType(TypeOfAny.unannotated) @@ -350,10 +376,11 @@ def get_field_get_type( is_nullable = self.get_field_nullability(field, method) if isinstance(field, RelatedField): related_model_cls = self.get_field_related_model_cls(field) + rel_model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if method in ("values", "values_list"): primary_key_field = self.get_primary_key_field(related_model_cls) - return self.get_field_get_type(api, primary_key_field, method=method) + return self.get_field_get_type(api, rel_model_info, primary_key_field, method=method) model_info = helpers.lookup_class_typeinfo(api, related_model_cls) if model_info is None: diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 05b1dad3c..b65206952 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -66,10 +66,12 @@ def get_field_type_from_lookup( elif (isinstance(lookup_field, RelatedField) and lookup_field.column == lookup) or isinstance( lookup_field, ForeignObjectRel ): - related_model_cls = django_context.get_field_related_model_cls(lookup_field) - lookup_field = django_context.get_primary_key_field(related_model_cls) + model_cls = django_context.get_field_related_model_cls(lookup_field) + lookup_field = django_context.get_primary_key_field(model_cls) - field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), lookup_field, method=method) + api = helpers.get_typechecker_api(ctx) + model_info = helpers.lookup_class_typeinfo(api, model_cls) + field_get_type = django_context.get_field_get_type(api, model_info, lookup_field, method=method) return field_get_type @@ -87,6 +89,7 @@ def get_values_list_row_type( return AnyType(TypeOfAny.from_error) typechecker_api = helpers.get_typechecker_api(ctx) + model_info = helpers.lookup_class_typeinfo(typechecker_api, model_cls) if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) @@ -98,7 +101,9 @@ def get_values_list_row_type( elif named: column_types: OrderedDict[str, MypyType] = OrderedDict() for field in django_context.get_model_fields(model_cls): - column_type = django_context.get_field_get_type(typechecker_api, field, method="values_list") + column_type = django_context.get_field_get_type( + typechecker_api, model_info, field, method="values_list" + ) column_types[field.attname] = column_type if is_annotated: # Return a NamedTuple with a fallback so that it's possible to access any field diff --git a/tests/typecheck/managers/querysets/test_values_list.yml b/tests/typecheck/managers/querysets/test_values_list.yml index 622a2d506..b20935b08 100644 --- a/tests/typecheck/managers/querysets/test_values_list.yml +++ b/tests/typecheck/managers/querysets/test_values_list.yml @@ -47,6 +47,26 @@ name = models.CharField(max_length=100) age = models.IntegerField() +- case: values_list_types_are_field_types + main: | + from myapp.models import Concrete + ret = list(Concrete.objects.values_list('id', 'data')) + reveal_type(ret) # N: Revealed type is "builtins.list[Tuple[builtins.int, builtins.dict[builtins.str, builtins.str]]]" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from __future__ import annotations + from django.db import models + + class JSONField(models.TextField): pass # incomplete + + class Concrete(models.Model): + id = models.IntegerField() + data: models.Field[dict[str, str], dict[str, str]] = JSONField() + - case: values_list_supports_queryset_methods main: | from myapp.models import MyUser