From e8ca7c1fb453a6f0b3de3268e2cea3434985c428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Mazzucotelli?= Date: Sun, 28 Apr 2024 13:25:51 +0200 Subject: [PATCH] feat: Support duty parameters annotated as type unions, with both old and modern syntax, even on Python 3.8 and 3.9 --- pyproject.toml | 1 + src/duty/validation.py | 25 ++++++++++++++++++++++--- tests/test_validation.py | 17 +++++++++++++++++ 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cc72960..43416ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "eval-type-backport; python_version < '3.10'", "failprint>=0.11,!=1.0.0", ] diff --git a/src/duty/validation.py b/src/duty/validation.py index 46bfdd2..735a2dc 100644 --- a/src/duty/validation.py +++ b/src/duty/validation.py @@ -9,9 +9,21 @@ import sys import textwrap +from contextlib import suppress from functools import cached_property from inspect import Parameter, Signature, signature -from typing import Any, Callable, Sequence +from typing import Any, Callable, ForwardRef, Sequence, Union, get_args, get_origin + +# TODO: Update once support for Python 3.9 is dropped. +if sys.version_info < (3, 10): + from eval_type_backport import eval_type_backport as eval_type + + union_types = (Union,) +else: + from types import UnionType + from typing import _eval_type as eval_type # type: ignore[attr-defined] + + union_types = (Union, UnionType) def to_bool(value: str) -> bool: @@ -40,6 +52,12 @@ def cast_arg(arg: Any, annotation: Any) -> Any: return arg if annotation is bool: annotation = to_bool + if get_origin(annotation) in union_types: + for sub_annotation in get_args(annotation): + if sub_annotation is type(None): + continue + with suppress(Exception): + return cast_arg(arg, sub_annotation) try: return annotation(arg) except Exception: # noqa: BLE001 @@ -187,9 +205,10 @@ def _get_params_caster(func: Callable, *args: Any, **kwargs: Any) -> ParamsCaste param.kind, default=param.default, annotation=( - eval( # noqa: PGH001,S307 - param.annotation, + eval_type( + ForwardRef(param.annotation) if isinstance(param.annotation, str) else param.annotation, exec_globals, + {}, ) if param.annotation is not Parameter.empty else type(param.default) diff --git a/tests/test_validation.py b/tests/test_validation.py index 23235b3..1a5f176 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -145,3 +145,20 @@ def func(ctx, a=0): # noqa: ANN202,ARG001,ANN001 caster = _get_params_caster(func, a="1") _, kwargs = caster.cast(a="1") assert kwargs == {"a": 1} + + +def test_validating_modern_annotations() -> None: + """Test modern type annotations in function signatures.""" + + def func(ctx, a: int | None = None): # noqa: ANN202,ARG001,ANN001 + ... + + caster = _get_params_caster(func, a=1) + _, kwargs = caster.cast(a="1") + assert kwargs == {"a": 1} + caster = _get_params_caster(func, a=None) + _, kwargs = caster.cast(a=None) + assert kwargs == {"a": None} + caster = _get_params_caster(func) + _, kwargs = caster.cast() + assert kwargs == {}