Skip to content

Commit

Permalink
feat: Support duty parameters annotated as type unions, with both old…
Browse files Browse the repository at this point in the history
… and modern syntax, even on Python 3.8 and 3.9
  • Loading branch information
pawamoy committed Apr 28, 2024
1 parent c5c46dd commit e8ca7c1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"eval-type-backport; python_version < '3.10'",
"failprint>=0.11,!=1.0.0",
]

Expand Down
25 changes: 22 additions & 3 deletions src/duty/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,21 @@

import sys
import textwrap
from contextlib import suppress
from functools import cached_property
from inspect import Parameter, Signature, signature
from typing import Any, Callable, Sequence
from typing import Any, Callable, ForwardRef, Sequence, Union, get_args, get_origin

# TODO: Update once support for Python 3.9 is dropped.
if sys.version_info < (3, 10):
from eval_type_backport import eval_type_backport as eval_type

union_types = (Union,)
else:
from types import UnionType
from typing import _eval_type as eval_type # type: ignore[attr-defined]

union_types = (Union, UnionType)


def to_bool(value: str) -> bool:
Expand Down Expand Up @@ -40,6 +52,12 @@ def cast_arg(arg: Any, annotation: Any) -> Any:
return arg
if annotation is bool:
annotation = to_bool
if get_origin(annotation) in union_types:
for sub_annotation in get_args(annotation):
if sub_annotation is type(None):
continue
with suppress(Exception):
return cast_arg(arg, sub_annotation)
try:
return annotation(arg)
except Exception: # noqa: BLE001
Expand Down Expand Up @@ -187,9 +205,10 @@ def _get_params_caster(func: Callable, *args: Any, **kwargs: Any) -> ParamsCaste
param.kind,
default=param.default,
annotation=(
eval( # noqa: PGH001,S307
param.annotation,
eval_type(
ForwardRef(param.annotation) if isinstance(param.annotation, str) else param.annotation,
exec_globals,
{},
)
if param.annotation is not Parameter.empty
else type(param.default)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,20 @@ def func(ctx, a=0): # noqa: ANN202,ARG001,ANN001
caster = _get_params_caster(func, a="1")
_, kwargs = caster.cast(a="1")
assert kwargs == {"a": 1}


def test_validating_modern_annotations() -> None:
"""Test modern type annotations in function signatures."""

def func(ctx, a: int | None = None): # noqa: ANN202,ARG001,ANN001
...

caster = _get_params_caster(func, a=1)
_, kwargs = caster.cast(a="1")
assert kwargs == {"a": 1}
caster = _get_params_caster(func, a=None)
_, kwargs = caster.cast(a=None)
assert kwargs == {"a": None}
caster = _get_params_caster(func)
_, kwargs = caster.cast()
assert kwargs == {}

0 comments on commit e8ca7c1

Please sign in to comment.