Skip to content

Commit

Permalink
Remove QuerySet alias hacks via PEP 696 TypeVar defaults (#2104)
Browse files Browse the repository at this point in the history
The `QuerySet` class was previously named `_QuerySet` and had three aliases: `QuerySet`, `QuerySetAny` and `ValuesQuerySet`.

These hacks were mainly needed to for the ergonomic single-parameter `QuerySet[Model]`, which expanded into `_QuerySet[Model, Model]`

But now that mypy 1.10 implements PEP 696 to a fuller extent (Pyright also supports it), the 2nd type parameter can be a simple TypeVar that defaults to 1st type parameter.
  • Loading branch information
intgr authored May 6, 2024
1 parent b0858a7 commit 4a5b065
Show file tree
Hide file tree
Showing 24 changed files with 124 additions and 133 deletions.
25 changes: 0 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,31 +254,6 @@ func(MyModel.objects.annotate(foo=Value("")).get(id=1)) # OK
func(MyModel.objects.annotate(bar=Value("")).get(id=1)) # Error
```

### How do I check if something is an instance of QuerySet in runtime?

A limitation of making `QuerySet` generic is that you can not use
it for `isinstance` checks.

```python
from django.db.models.query import QuerySet

def foo(obj: object) -> None:
if isinstance(obj, QuerySet): # Error: Parameterized generics cannot be used with class or instance checks
...
```

To get around with this issue without making `QuerySet` non-generic,
Django-stubs provides `django_stubs_ext.QuerySetAny`, a non-generic
variant of `QuerySet` suitable for runtime type checking:

```python
from django_stubs_ext import QuerySetAny

def foo(obj: object) -> None:
if isinstance(obj, QuerySetAny): # OK
...
```

### Why am I getting incompatible argument type mentioning `_StrPromise`?

The lazy translation functions of Django (such as `gettext_lazy`) return a `Promise` instead of `str`. These two types [cannot be used interchangeably](https://github.com/typeddjango/django-stubs/pull/1139#issuecomment-1232167698). The return type of these functions was therefore [changed](https://github.com/typeddjango/django-stubs/pull/689) to reflect that.
Expand Down
12 changes: 4 additions & 8 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ from django.db.models.expressions import Combinable, OrderBy
from django.db.models.query import QuerySet, RawQuerySet
from typing_extensions import Self

from django_stubs_ext import ValuesQuerySet

_T = TypeVar("_T", bound=Model, covariant=True)

class BaseManager(Generic[_T]):
Expand Down Expand Up @@ -107,15 +105,13 @@ class BaseManager(Generic[_T]):
using: str | None = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: str | Combinable, **expressions: Any) -> ValuesQuerySet[_T, dict[str, Any]]: ...
def values(self, *fields: str | Combinable, **expressions: Any) -> QuerySet[_T, dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(
self, *fields: str | Combinable, flat: bool = ..., named: bool = ...
) -> ValuesQuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> ValuesQuerySet[_T, datetime.date]: ...
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_T, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: datetime.tzinfo | None = ...
) -> ValuesQuerySet[_T, datetime.datetime]: ...
) -> QuerySet[_T, datetime.datetime]: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
Expand Down
96 changes: 49 additions & 47 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
@@ -1,32 +1,33 @@
import datetime
from collections.abc import AsyncIterator, Collection, Iterable, Iterator, MutableMapping, Sequence, Sized
from typing import Any, Generic, NamedTuple, TypeVar, overload
from typing import Any, Generic, NamedTuple, overload

from django.db.backends.utils import _ExecuteQuery
from django.db.models import Manager
from django.db.models.base import Model
from django.db.models.expressions import Combinable, OrderBy
from django.db.models.sql.query import Query, RawQuery
from django.utils.functional import cached_property
from typing_extensions import Self, TypeAlias
from typing_extensions import Self, TypeAlias, TypeVar

_T = TypeVar("_T", bound=Model, covariant=True)
_Row = TypeVar("_Row", covariant=True)
_T = TypeVar("_T", covariant=True)
_Model = TypeVar("_Model", bound=Model, covariant=True)
_Row = TypeVar("_Row", covariant=True, default=_Model) # ONLY use together with _Model
_QS = TypeVar("_QS", bound=_QuerySet)
_TupleT = TypeVar("_TupleT", bound=tuple[Any, ...], covariant=True)

MAX_GET_RESULTS: int
REPR_OUTPUT_SIZE: int

class BaseIterable(Generic[_Row]):
class BaseIterable(Generic[_T]):
queryset: QuerySet[Model]
chunked_fetch: bool
chunk_size: int
def __init__(self, queryset: QuerySet[Model], chunked_fetch: bool = ..., chunk_size: int = ...) -> None: ...
def __aiter__(self) -> AsyncIterator[_Row]: ...
def __aiter__(self) -> AsyncIterator[_T]: ...

class ModelIterable(Generic[_T], BaseIterable[_T]):
def __iter__(self) -> Iterator[_T]: ...
class ModelIterable(Generic[_Model], BaseIterable[_Model]):
def __iter__(self) -> Iterator[_Model]: ...

class RawModelIterable(BaseIterable[dict[str, Any]]):
def __iter__(self) -> Iterator[dict[str, Any]]: ...
Expand All @@ -40,11 +41,11 @@ class ValuesListIterable(BaseIterable[_TupleT]):
class NamedValuesListIterable(ValuesListIterable[NamedTuple]):
def __iter__(self) -> Iterator[NamedTuple]: ...

class FlatValuesListIterable(BaseIterable[_Row]):
def __iter__(self) -> Iterator[_Row]: ...
class FlatValuesListIterable(BaseIterable[_T]):
def __iter__(self) -> Iterator[_T]: ...

class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
model: type[_T]
class QuerySet(Generic[_Model, _Row], Iterable[_Row], Sized):
model: type[_Model]
query: Query
_iterable_class: type[BaseIterable]
_result_cache: list[_Row] | None
Expand All @@ -56,14 +57,14 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
hints: dict[str, Model] | None = ...,
) -> None: ...
@classmethod
def as_manager(cls) -> Manager[_T]: ...
def as_manager(cls) -> Manager[_Model]: ...
def __len__(self) -> int: ...
def __bool__(self) -> bool: ...
def __class_getitem__(cls: type[_QS], item: type[_T]) -> type[_QS]: ...
def __class_getitem__(cls: type[_QS], item: type[_Model]) -> type[_QS]: ...
def __getstate__(self) -> dict[str, Any]: ...
# Technically, the other QuerySet must be of the same type _T, but _T is covariant
def __and__(self, other: _QuerySet[_T, _Row]) -> Self: ...
def __or__(self, other: _QuerySet[_T, _Row]) -> Self: ...
def __and__(self, other: QuerySet[_Model, _Row]) -> Self: ...
def __or__(self, other: QuerySet[_Model, _Row]) -> Self: ...
# IMPORTANT: When updating any of the following methods' signatures, please ALSO modify
# the corresponding method in BaseManager.
def iterator(self, chunk_size: int | None = ...) -> Iterator[_Row]: ...
Expand All @@ -72,44 +73,46 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
async def aaggregate(self, *args: Any, **kwargs: Any) -> dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
async def aget(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ...
async def acreate(self, **kwargs: Any) -> _T: ...
def create(self, **kwargs: Any) -> _Model: ...
async def acreate(self, **kwargs: Any) -> _Model: ...
def bulk_create(
self,
objs: Iterable[_T],
objs: Iterable[_Model],
batch_size: int | None = ...,
ignore_conflicts: bool = ...,
update_conflicts: bool = ...,
update_fields: Collection[str] | None = ...,
unique_fields: Collection[str] | None = ...,
) -> list[_T]: ...
) -> list[_Model]: ...
async def abulk_create(
self,
objs: Iterable[_T],
objs: Iterable[_Model],
batch_size: int | None = ...,
ignore_conflicts: bool = ...,
update_conflicts: bool = ...,
update_fields: Collection[str] | None = ...,
unique_fields: Collection[str] | None = ...,
) -> list[_T]: ...
def bulk_update(self, objs: Iterable[_T], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
async def abulk_update(self, objs: Iterable[_T], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
def get_or_create(self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any) -> tuple[_T, bool]: ...
) -> list[_Model]: ...
def bulk_update(self, objs: Iterable[_Model], fields: Iterable[str], batch_size: int | None = ...) -> int: ...
async def abulk_update(
self, objs: Iterable[_Model], fields: Iterable[str], batch_size: int | None = ...
) -> int: ...
def get_or_create(self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any) -> tuple[_Model, bool]: ...
async def aget_or_create(
self, defaults: MutableMapping[str, Any] | None = ..., **kwargs: Any
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
def update_or_create(
self,
defaults: MutableMapping[str, Any] | None = ...,
create_defaults: MutableMapping[str, Any] | None = ...,
**kwargs: Any,
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
async def aupdate_or_create(
self,
defaults: MutableMapping[str, Any] | None = ...,
create_defaults: MutableMapping[str, Any] | None = ...,
**kwargs: Any,
) -> tuple[_T, bool]: ...
) -> tuple[_Model, bool]: ...
def earliest(self, *fields: str | OrderBy) -> _Row: ...
async def aearliest(self, *fields: str | OrderBy) -> _Row: ...
def latest(self, *fields: str | OrderBy) -> _Row: ...
Expand All @@ -118,8 +121,8 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
async def afirst(self) -> _Row | None: ...
def last(self) -> _Row | None: ...
async def alast(self) -> _Row | None: ...
def in_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _T]: ...
async def ain_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _T]: ...
def in_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _Model]: ...
async def ain_bulk(self, id_list: Iterable[Any] | None = ..., *, field_name: str = ...) -> dict[Any, _Model]: ...
def delete(self) -> tuple[int, dict[str, int]]: ...
async def adelete(self) -> tuple[int, dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -138,13 +141,13 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
using: str | None = ...,
) -> RawQuerySet: ...
# The type of values may be overridden to be more specific in the mypy plugin, depending on the fields param
def values(self, *fields: str | Combinable, **expressions: Any) -> _QuerySet[_T, dict[str, Any]]: ...
def values(self, *fields: str | Combinable, **expressions: Any) -> QuerySet[_Model, dict[str, Any]]: ...
# The type of values_list may be overridden to be more specific in the mypy plugin, depending on the fields param
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> _QuerySet[_T, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> _QuerySet[_T, datetime.date]: ...
def values_list(self, *fields: str | Combinable, flat: bool = ..., named: bool = ...) -> QuerySet[_Model, Any]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet[_Model, datetime.date]: ...
def datetimes(
self, field_name: str, kind: str, order: str = ..., tzinfo: datetime.tzinfo | None = ...
) -> _QuerySet[_T, datetime.datetime]: ...
) -> QuerySet[_Model, datetime.datetime]: ...
def none(self) -> Self: ...
def all(self) -> Self: ...
def filter(self, *args: Any, **kwargs: Any) -> Self: ...
Expand Down Expand Up @@ -173,7 +176,7 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
tables: Sequence[str] | None = ...,
order_by: Sequence[str] | None = ...,
select_params: Sequence[Any] | None = ...,
) -> _QuerySet[Any, Any]: ...
) -> QuerySet[Any, Any]: ...
def reverse(self) -> Self: ...
def defer(self, *fields: Any) -> Self: ...
def only(self, *fields: Any) -> Self: ...
Expand All @@ -192,7 +195,7 @@ class _QuerySet(Generic[_T, _Row], Iterable[_Row], Sized):
def __getitem__(self, s: slice) -> Self: ...
def __reversed__(self) -> Iterator[_Row]: ...

class RawQuerySet(Iterable[_T], Sized):
class RawQuerySet(Iterable[_Model], Sized):
query: RawQuery
def __init__(
self,
Expand All @@ -205,28 +208,27 @@ class RawQuerySet(Iterable[_T], Sized):
hints: dict[str, Model] | None = ...,
) -> None: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __iter__(self) -> Iterator[_Model]: ...
def __bool__(self) -> bool: ...
@overload
def __getitem__(self, k: int) -> _T: ...
def __getitem__(self, k: int) -> _Model: ...
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> RawQuerySet[_T]: ...
def __getitem__(self, k: slice) -> RawQuerySet[_Model]: ...
@cached_property
def columns(self) -> list[str]: ...
@property
def db(self) -> str: ...
def iterator(self) -> Iterator[_T]: ...
def iterator(self) -> Iterator[_Model]: ...
@cached_property
def model_fields(self) -> dict[str, str]: ...
def prefetch_related(self, *lookups: Any) -> RawQuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> RawQuerySet[_Model]: ...
def resolve_model_init_order(self) -> tuple[list[str], list[int], list[tuple[str, int]]]: ...
def using(self, alias: str | None) -> RawQuerySet[_T]: ...

_QuerySetAny: TypeAlias = _QuerySet # noqa: PYI047
def using(self, alias: str | None) -> RawQuerySet[_Model]: ...

QuerySet: TypeAlias = _QuerySet[_T, _T]
# Deprecated alias of QuerySet, for compatibility only.
_QuerySet: TypeAlias = QuerySet

class Prefetch:
prefetch_through: str
Expand All @@ -240,8 +242,8 @@ class Prefetch:
def get_current_to_attr(self, level: int) -> tuple[str, str]: ...
def get_current_queryset(self, level: int) -> QuerySet | None: ...

def prefetch_related_objects(model_instances: Iterable[_T], *related_lookups: str | Prefetch) -> None: ...
async def aprefetch_related_objects(model_instances: Iterable[_T], *related_lookups: str | Prefetch) -> None: ...
def prefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...
async def aprefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...
def get_prefetcher(instance: Model, through_attr: str, to_attr: str) -> tuple[Any, Any, bool, bool]: ...

class InstanceCheckMeta(type): ...
Expand Down
4 changes: 3 additions & 1 deletion ext/django_stubs_ext/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
from django.utils.functional import _StrOrPromise as StrOrPromise
from django.utils.functional import _StrPromise as StrPromise

# Deprecated type aliases. Use the QuerySet class directly instead.
QuerySetAny = _QuerySet
ValuesQuerySet = _QuerySet
else:
from django.db.models.query import QuerySet
from django.utils.functional import Promise as StrPromise

StrOrPromise = typing.Union[str, StrPromise]
# Deprecated type aliases. Use the QuerySet class directly instead.
QuerySetAny = QuerySet
ValuesQuerySet = QuerySet
StrOrPromise = typing.Union[str, StrPromise]

__all__ = ["StrOrPromise", "StrPromise", "QuerySetAny", "ValuesQuerySet"]
2 changes: 1 addition & 1 deletion mypy_django_plugin/lib/fullnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject"
AUTH_USER_MODEL_FULLNAME = "django.conf.settings.AUTH_USER_MODEL"

QUERYSET_CLASS_FULLNAME = "django.db.models.query._QuerySet"
QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet"
BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager"
MANAGER_CLASS_FULLNAME = "django.db.models.manager.Manager"
RELATED_MANAGER_CLASS = "django.db.models.fields.related_descriptors.RelatedManager"
Expand Down
2 changes: 1 addition & 1 deletion mypy_django_plugin/transformers/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def example(self, a: T2) -> T_2: ...
return False

if type_info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
# If it is a subclass of _QuerySet, it is compatible.
# If it is a subclass of QuerySet, it is compatible.
return True
# check that at least one base is a subclass of queryset with Generic type vars
return any(_has_compatible_type_vars(sub_base.type) for sub_base in type_info.bases)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def find_stub_files(name: str) -> List[str]:
"django-stubs-ext>=5.0.0",
"tomli; python_version < '3.11'",
# Types:
"typing-extensions",
"typing-extensions>=4.11.0",
"types-PyYAML",
]

Expand Down
8 changes: 4 additions & 4 deletions tests/typecheck/contrib/admin/test_decorators.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
def method_action_invalid_fancy(self, request: HttpRequest, queryset: int) -> None: ...
def method(self) -> None:
reveal_type(self.method_action_bare) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_fancy) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_http_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel]) -> django.http.response.HttpResponse"
reveal_type(self.method_action_file_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query._QuerySet[main.MyModel, main.MyModel]) -> django.http.response.FileResponse"
reveal_type(self.method_action_bare) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_fancy) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel])"
reveal_type(self.method_action_http_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel]) -> django.http.response.HttpResponse"
reveal_type(self.method_action_file_response) # N: Revealed type is "def (django.http.request.HttpRequest, django.db.models.query.QuerySet[main.MyModel, main.MyModel]) -> django.http.response.FileResponse"
2 changes: 1 addition & 1 deletion tests/typecheck/contrib/admin/test_options.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
pass
class A(admin.ModelAdmin):
actions = [an_action] # E: List item 0 has incompatible type "Callable[[None], None]"; expected "Union[Callable[[Any, HttpRequest, _QuerySet[Any, Any]], Optional[HttpResponseBase]], str]" [list-item]
actions = [an_action] # E: List item 0 has incompatible type "Callable[[None], None]"; expected "Union[Callable[[Any, HttpRequest, QuerySet[Any, Any]], Optional[HttpResponseBase]], str]" [list-item]
- case: errors_for_invalid_model_admin_generic
main: |
from django.contrib.admin import ModelAdmin
Expand Down
2 changes: 1 addition & 1 deletion tests/typecheck/contrib/sitemaps/test_generic_sitemap.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
main:26: error: Argument 1 of "location" is incompatible with supertype "Sitemap"; supertype defines the argument type as "Offer" [override]
main:26: note: This violates the Liskov substitution principle
main:26: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides
main:40: error: Argument 1 to "GenericSitemap" has incompatible type "Dict[str, List[int]]"; expected "Mapping[str, Union[datetime, _QuerySet[Offer, Offer], str]]" [arg-type]
main:40: error: Argument 1 to "GenericSitemap" has incompatible type "Dict[str, List[int]]"; expected "Mapping[str, Union[datetime, QuerySet[Offer, Offer], str]]" [arg-type]
installed_apps:
- myapp
Expand Down
Loading

0 comments on commit 4a5b065

Please sign in to comment.