Skip to content

Commit

Permalink
Refactor get_field_related_model_cls to raise UnregisteredModelError
Browse files Browse the repository at this point in the history
  • Loading branch information
UnknownPlatypus committed Jun 16, 2023
1 parent f6457d1 commit 1ce442c
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 22 deletions.
20 changes: 15 additions & 5 deletions mypy_django_plugin/django/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from mypy.types import AnyType, Instance, TypeOfAny, UnionType
from mypy.types import Type as MypyType

from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME

Expand Down Expand Up @@ -123,7 +124,14 @@ def get_model_foreign_keys(self, model_cls: Type[Model]) -> Iterator["ForeignKey
if isinstance(field, ForeignKey):
yield field

def get_model_related_fields(self, model_cls: Type[Model]) -> Iterator["RelatedField[Any, Any]"]:
"""Get model forward relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, RelatedField):
yield field

def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]:
"""Get model reverse relations"""
for field in model_cls._meta.get_fields():
if isinstance(field, ForeignObjectRel):
yield field
Expand Down Expand Up @@ -334,9 +342,7 @@ def get_field_get_type(
else:
return helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_nullable)

def get_field_related_model_cls(
self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]
) -> Optional[Type[Model]]:
def get_field_related_model_cls(self, field: Union["RelatedField[Any, Any]", ForeignObjectRel]) -> Type[Model]:
if isinstance(field, RelatedField):
related_model_cls = field.remote_field.model
else:
Expand All @@ -350,11 +356,13 @@ def get_field_related_model_cls(
# same file model
related_model_fullname = field.model.__module__ + "." + related_model_cls
related_model_cls = self.get_model_class_by_fullname(related_model_fullname)
if related_model_cls is None:
raise UnregisteredModelError
else:
try:
related_model_cls = self.apps_registry.get_model(related_model_cls)
except LookupError:
return None
except LookupError as e:
raise UnregisteredModelError from e

return related_model_cls

Expand Down Expand Up @@ -442,6 +450,8 @@ def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model
except FieldError as exc:
ctx.api.fail(exc.args[0], ctx.context)
return AnyType(TypeOfAny.from_error)
except UnregisteredModelError:
return AnyType(TypeOfAny.from_error)

if solved_lookup is None:
return AnyType(TypeOfAny.implementation_artifact)
Expand Down
2 changes: 2 additions & 0 deletions mypy_django_plugin/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class UnregisteredModelError(Exception):
"""The requested model is not registered"""
26 changes: 13 additions & 13 deletions mypy_django_plugin/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import itertools
import sys
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, Type

from django.db.models.fields.related import RelatedField
from mypy.modulefinder import mypy_path
from mypy.nodes import MypyFile, TypeInfo
from mypy.options import Options
Expand All @@ -20,6 +20,7 @@
import mypy_django_plugin.transformers.orm_lookups
from mypy_django_plugin.config import DjangoPluginConfig
from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.transformers import fields, forms, init_create, meta, querysets, request, settings
from mypy_django_plugin.transformers.functional import resolve_str_promise_attribute
Expand Down Expand Up @@ -147,23 +148,22 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
if not defined_model_classes:
return []
deps = set()

for model_class in defined_model_classes:
# forward relations
for field in self.django_context.get_model_fields(model_class):
if isinstance(field, RelatedField):
for field in itertools.chain(
# forward relations
self.django_context.get_model_related_fields(model_class),
# reverse relations - `related_objects` is private API (according to docstring)
model_class._meta.related_objects, # type: ignore[attr-defined]
):
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))
# reverse relations
# `related_objects` is private API (according to docstring)
for relation in model_class._meta.related_objects: # type: ignore[attr-defined]
related_model_cls = self.django_context.get_field_related_model_cls(relation)
except UnregisteredModelError:
continue
related_model_module = related_model_cls.__module__
if related_model_module != file.fullname:
deps.add(self._new_dependency(related_model_module))

return list(deps) + [
# for QuerySet.annotate
self._new_dependency("django_stubs_ext"),
Expand Down
6 changes: 4 additions & 2 deletions mypy_django_plugin/transformers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mypy.types import Type as MypyType

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,8 +60,9 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context

assert isinstance(current_field, RelatedField)

related_model_cls = django_context.get_field_related_model_cls(current_field)
if related_model_cls is None:
try:
related_model_cls = django_context.get_field_related_model_cls(current_field)
except UnregisteredModelError:
return AnyType(TypeOfAny.from_error)

default_related_field_type = set_descriptor_types_for_field(ctx)
Expand Down
6 changes: 4 additions & 2 deletions mypy_django_plugin/transformers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from mypy_django_plugin.django.context import DjangoContext
from mypy_django_plugin.errorcodes import MANAGER_MISSING
from mypy_django_plugin.exceptions import UnregisteredModelError
from mypy_django_plugin.lib import fullnames, helpers
from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME
from mypy_django_plugin.transformers import fields
Expand Down Expand Up @@ -234,8 +235,9 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
class AddRelatedModelsId(ModelClassInitializer):
def run_with_model_cls(self, model_cls: Type[Model]) -> None:
for field in self.django_context.get_model_foreign_keys(model_cls):
related_model_cls = self.django_context.get_field_related_model_cls(field)
if related_model_cls is None:
try:
related_model_cls = self.django_context.get_field_related_model_cls(field)
except UnregisteredModelError:
error_context: Context = self.ctx.cls
field_sym = self.ctx.cls.info.get(field.name)
if field_sym is not None and field_sym.node is not None:
Expand Down

0 comments on commit 1ce442c

Please sign in to comment.