diff --git a/django_choices_field/fields.py b/django_choices_field/fields.py index 9bc3d27..c9b5ac1 100644 --- a/django_choices_field/fields.py +++ b/django_choices_field/fields.py @@ -1,3 +1,5 @@ +from typing import Optional, Type + from django.core.exceptions import ValidationError from django.db import models @@ -8,7 +10,13 @@ class ChoicesField(models.CharField): "invalid": "ā€œ%(value)sā€ must be a subclass of %(enum)s.", } - def __init__(self, verbose_name=None, name=None, choices_enum=None, **kwargs): + def __init__( + self, + choices_enum: Type[models.TextChoices], + verbose_name: Optional[str] = None, + name: Optional[str] = None, + **kwargs, + ): self.choices_enum = choices_enum kwargs["choices"] = choices_enum.choices kwargs.setdefault("max_length", max(len(c.value) for c in choices_enum)) @@ -20,7 +28,6 @@ def deconstruct(self): return name, path, args, kwargs def to_python(self, value): - print(1, value, type(value)) if value is None: return None diff --git a/django_choices_field/fields.pyi b/django_choices_field/fields.pyi index f79485c..7e560ba 100644 --- a/django_choices_field/fields.pyi +++ b/django_choices_field/fields.pyi @@ -4,6 +4,7 @@ from typing import ( Dict, Generic, Iterable, + Literal, Optional, Tuple, Type, @@ -78,7 +79,7 @@ class ChoicesField(Generic[_C], Field[_C, _C]): def __get__(self: ChoicesField[_C], instance: Any, owner: Any) -> _C: ... @overload def __get__(self: ChoicesField[Optional[_C]], instance: Any, owner: Any) -> Optional[_C]: ... - @overload - def __set__(self, instance: ChoicesField[_C], value: _C) -> None: ... - @overload - def __set__(self, instance: ChoicesField[Optional[_C]], value: Optional[_C]) -> None: ... + # @overload + # def __set__(self, instance: ChoicesField[_C], value: _C) -> None: ... + # @overload + # def __set__(self, instance: ChoicesField[Optional[_C]], value: Optional[_C]) -> None: ... diff --git a/pyproject.toml b/pyproject.toml index 1500538..a971404 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ multi_line_output = 3 force_sort_within_sections = true [tool.pyright] -pythonVersion = "3.7" +pythonVersion = "3.8" useLibraryCodeForTypes = true [tool.pytest.ini_options] diff --git a/tests/test_fields.py b/tests/test_fields.py index 542f590..05dfc78 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -44,7 +44,7 @@ def test_set_none(db): def test_set_wrong_value(db): m = MyModel() with pytest.raises(ValidationError) as exc: - m.c_field = 1 + m.c_field = 1 # type:ignore (we really want the error here) m.save() assert list(exc.value) == ["ā€œ1ā€ must be a subclass of ."]