Skip to content

Commit

Permalink
Add IntegerChoices support
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Jul 17, 2021
1 parent 865b7a3 commit 8052f21
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 61 deletions.
26 changes: 18 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
![python version](https://img.shields.io/pypi/pyversions/django-choices-field.svg)
![django version](https://img.shields.io/pypi/djversions/django-choices-field.svg)

Django field that set/get django's new TextChoices enum.
Django field that set/get django's new TextChoices/IntegerChoices enum.

## Install

Expand All @@ -18,23 +18,33 @@ pip install django-choices-field

```python
from django.db import models
from django_choices_field import ChoicesField
from django_choices_field import TexChoicesField, IntegerChoicesField


class MyModel(models.Model):
class MyEnum(models.TextChoices):
class TextEnum(models.TextChoices):
FOO = "foo", "Foo Description"
BAR = "bar", "Bar Description"

c_field = ChoicesField(
choices_enum=MyEnum,
default=MyEnum.FOO,
class IntegerEnum(models.TextChoices):
FIRST = 1, "First Description"
SECOND = 2, "Second Description"

c_field = TextChoicesField(
choices_enum=TextEnum,
default=TextEnum.FOO,
)
i_field = IntegerChoicesField(
choices_enum=IntegerEnum,
default=IntegerEnum.FIRST,
)


obj = MyModel()
obj.c_field # MyModel.MyEnum.FOO
isinstance(obj.c_field, MyModel.MyEnum) # True
obj.c_field # MyModel.TextEnum.FOO
isinstance(obj.c_field, MyModel.TextEnum) # True
obj.i_field # MyModel.IntegerEnum.FIRST
isinstance(obj.i_field, MyModel.IntegerEnum) # True
```

## License
Expand Down
5 changes: 3 additions & 2 deletions django_choices_field/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .fields import ChoicesField # noqa
from .fields import IntegerChoicesField, TextChoicesField # noqa

__all__ = [
"ChoicesField",
"TextChoicesField",
"IntegerChoicesField",
]
50 changes: 45 additions & 5 deletions django_choices_field/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from django.db import models


class ChoicesField(models.CharField):
description = "Choices"
class TextChoicesField(models.CharField):
description = "TextChoices"
default_error_messages = {
"invalid": "“%(value)s” must be a subclass of %(enum)s.",
}
Expand All @@ -31,8 +31,48 @@ def to_python(self, value):
if value is None:
return None

if isinstance(value, self.choices_enum):
return value
try:
return self.choices_enum(value)
except ValueError:
raise ValidationError(
self.error_messages["invalid"],
code="invalid",
params={"value": value, "enum": self.choices_enum},
)

def from_db_value(self, value, expression, connection):
return self.to_python(value)

def get_prep_value(self, value):
value = super().get_prep_value(value)
return self.to_python(value)


class IntegerChoicesField(models.IntegerField):
description = "IntegerChoices"
default_error_messages = {
"invalid": "“%(value)s” must be a subclass of %(enum)s.",
}

def __init__(
self,
choices_enum: Type[models.IntegerChoices],
verbose_name: Optional[str] = None,
name: Optional[str] = None,
**kwargs,
):
self.choices_enum = choices_enum
kwargs["choices"] = choices_enum.choices
super().__init__(verbose_name=verbose_name, name=name, **kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs["choices_enum"] = self.choices_enum
return name, path, args, kwargs

def to_python(self, value):
if value is None:
return None

try:
return self.choices_enum(value)
Expand All @@ -48,4 +88,4 @@ def from_db_value(self, value, expression, connection):

def get_prep_value(self, value):
value = super().get_prep_value(value)
return value.value if value is not None else None
return self.to_python(value)
85 changes: 73 additions & 12 deletions django_choices_field/fields.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ from typing import (
overload,
)

from django.db.models import Field, TextChoices
from django.db.models import Field, IntegerChoices, TextChoices

_Choice = Tuple[Any, str]
_ChoiceNamedGroup = Tuple[str, Iterable[_Choice]]
_FieldChoices = Iterable[Union[_Choice, _ChoiceNamedGroup]]
_ValidatorCallable = Callable[..., None]
_ErrorMessagesToOverride = Dict[str, Any]

_C = TypeVar("_C", bound="Optional[TextChoices]")

class ChoicesField(Generic[_C], Field[_C, _C]):
class TextChoicesField(Generic[_C], Field[_C, _C]):
@overload
def __new__(
cls,
Expand All @@ -35,7 +36,7 @@ class ChoicesField(Generic[_C], Field[_C, _C]):
blank: bool = ...,
null: Literal[False] = ...,
db_index: bool = ...,
default: Any = ...,
default: Optional[_C] = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
Expand All @@ -48,7 +49,7 @@ class ChoicesField(Generic[_C], Field[_C, _C]):
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> ChoicesField[_C]: ...
) -> TextChoicesField[_C]: ...
@overload
def __new__(
cls,
Expand All @@ -61,7 +62,69 @@ class ChoicesField(Generic[_C], Field[_C, _C]):
blank: bool = ...,
null: Literal[True] = ...,
db_index: bool = ...,
default: Any = ...,
default: Optional[_C] = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> TextChoicesField[Optional[_C]]: ...
@overload
def __get__(self: TextChoicesField[_C], instance: Any, owner: Any) -> _C: ...
@overload
def __get__(
self: TextChoicesField[Optional[_C]], instance: Any, owner: Any
) -> Optional[_C]: ...

_I = TypeVar("_I", bound="Optional[IntegerChoices]")

class IntegerChoicesField(Generic[_I], Field[_I, _I]):
@overload
def __new__(
cls,
choices_enum: Type[_I],
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: Literal[False] = ...,
db_index: bool = ...,
default: Optional[_I] = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
unique_for_date: Optional[str] = ...,
unique_for_month: Optional[str] = ...,
unique_for_year: Optional[str] = ...,
choices: Optional[_FieldChoices] = ...,
help_text: str = ...,
db_column: Optional[str] = ...,
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> IntegerChoicesField[_I]: ...
@overload
def __new__(
cls,
choices_enum: Type[_I],
verbose_name: Optional[Union[str, bytes]] = ...,
name: Optional[str] = ...,
primary_key: bool = ...,
max_length: Optional[int] = ...,
unique: bool = ...,
blank: bool = ...,
null: Literal[True] = ...,
db_index: bool = ...,
default: Optional[_I] = ...,
editable: bool = ...,
auto_created: bool = ...,
serialize: bool = ...,
Expand All @@ -74,12 +137,10 @@ class ChoicesField(Generic[_C], Field[_C, _C]):
db_tablespace: Optional[str] = ...,
validators: Iterable[_ValidatorCallable] = ...,
error_messages: Optional[_ErrorMessagesToOverride] = ...,
) -> ChoicesField[Optional[_C]]: ...
) -> IntegerChoicesField[Optional[_I]]: ...
@overload
def __get__(self: ChoicesField[_C], instance: Any, owner: Any) -> _C: ...
def __get__(self: IntegerChoicesField[_I], instance: Any, owner: Any) -> _I: ...
@overload
def __get__(self: ChoicesField[Optional[_C]], instance: Any, owner: Any) -> Optional[_C]: ...
# @overload
# def __set__(self, instance: ChoicesField[_C], value: _C) -> None: ...
# @overload
# def __set__(self, instance: ChoicesField[Optional[_C]], value: Optional[_C]) -> None: ...
def __get__(
self: IntegerChoicesField[Optional[_I]], instance: Any, owner: Any
) -> Optional[_I]: ...
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]
name = "django-choices-field"
version = "1.0"
description = "Django field that set/get django's new TextChoices enum."
version = "1.1"
description = "Django field that set/get django's new TextChoices/IntegerChoices enum."
authors = ["Thiago Bellini Ribeiro <[email protected]>"]
license = "MIT"
readme = "README.md"
Expand Down
30 changes: 21 additions & 9 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from django.db import models

from django_choices_field import ChoicesField
from django_choices_field import IntegerChoicesField, TextChoicesField


class MyModel(models.Model):
class MyEnum(models.TextChoices):
FOO = "foo", "Foo Description"
BAR = "bar", "Bar Description"
class TextEnum(models.TextChoices):
C_FOO = "foo", "T Foo Description"
C_BAR = "bar", "T Bar Description"

class IntegerEnum(models.IntegerChoices):
I_FOO = 1, "I Foo Description"
I_BAR = 2, "I Bar Description"

objects = models.Manager["MyModel"]()

c_field = ChoicesField(
choices_enum=MyEnum,
default=MyEnum.FOO,
c_field = TextChoicesField(
choices_enum=TextEnum,
default=TextEnum.C_FOO,
)
c_field_nullable = TextChoicesField(
choices_enum=TextEnum,
null=True,
)
i_field = IntegerChoicesField(
choices_enum=IntegerEnum,
default=IntegerEnum.I_FOO,
)
c_field_nullable = ChoicesField(
choices_enum=MyEnum,
i_field_nullable = IntegerChoicesField(
choices_enum=IntegerEnum,
null=True,
)
Loading

0 comments on commit 8052f21

Please sign in to comment.