diff --git a/pyproject.toml b/pyproject.toml index ac89bbd9bb..48615c557a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ classifiers = [ ] dependencies = [ "click >= 8.0.0", - "typing-extensions >= 3.7.4.3", + "typing-extensions >= 4.6.0", ] readme = "README.md" [project.urls] diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000000..d047fbc2e3 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,139 @@ +from datetime import datetime +from enum import Enum +from pathlib import Path +from uuid import UUID + +import click +import pytest +from typer._typing import TypeAliasType +from typer.main import get_click_type +from typer.models import FileBinaryRead, FileTextWrite, ParameterInfo + + +def test_get_click_type_with_custom_click_type(): + custom_click_type = click.INT + param_info = ParameterInfo(click_type=custom_click_type) + result = get_click_type(annotation=int, parameter_info=param_info) + assert result is custom_click_type + + +def test_get_click_type_with_custom_parser(): + def mock_parser(x): + return 42 + + param_info = ParameterInfo(parser=mock_parser) + result = get_click_type(annotation=int, parameter_info=param_info) + assert isinstance(result, click.types.FuncParamType) + assert result.convert("42", None, None) == 42 + + +def test_get_click_type_with_str_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=str, parameter_info=param_info) + assert result is click.STRING + + +def test_get_click_type_with_int_annotation_no_min_max(): + param_info = ParameterInfo() + result = get_click_type(annotation=int, parameter_info=param_info) + assert result is click.INT + + +def test_get_click_type_with_int_annotation_with_min_max(): + param_info = ParameterInfo(min=10, max=100) + result = get_click_type(annotation=int, parameter_info=param_info) + assert isinstance(result, click.IntRange) + assert result.min == 10 + assert result.max == 100 + + +def test_get_click_type_with_float_annotation_no_min_max(): + param_info = ParameterInfo() + result = get_click_type(annotation=float, parameter_info=param_info) + assert result is click.FLOAT + + +def test_get_click_type_with_float_annotation_with_min_max(): + param_info = ParameterInfo(min=0.1, max=10.5) + result = get_click_type(annotation=float, parameter_info=param_info) + assert isinstance(result, click.FloatRange) + assert result.min == 0.1 + assert result.max == 10.5 + + +def test_get_click_type_with_bool_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=bool, parameter_info=param_info) + assert result is click.BOOL + + +def test_get_click_type_with_uuid_annotation(): + param_info = ParameterInfo() + result = get_click_type(annotation=UUID, parameter_info=param_info) + assert result is click.UUID + + +def test_get_click_type_with_datetime_annotation(): + param_info = ParameterInfo(formats=["%Y-%m-%d"]) + result = get_click_type(annotation=datetime, parameter_info=param_info) + assert isinstance(result, click.DateTime) + assert result.formats == ["%Y-%m-%d"] + + +def test_get_click_type_with_path_annotation(): + param_info = ParameterInfo(resolve_path=True) + result = get_click_type(annotation=Path, parameter_info=param_info) + assert isinstance(result, click.Path) + assert result.resolve_path is True + + +def test_get_click_type_with_enum_annotation(): + class Color(Enum): + RED = "red" + BLUE = "blue" + + param_info = ParameterInfo() + result = get_click_type(annotation=Color, parameter_info=param_info) + assert isinstance(result, click.Choice) + assert result.choices == ("red", "blue") + + +def test_get_click_type_with_file_text_write_annotation(): + param_info = ParameterInfo(mode="w", encoding="utf-8") + result = get_click_type(annotation=FileTextWrite, parameter_info=param_info) + assert isinstance(result, click.File) + assert result.mode == "w" + assert result.encoding == "utf-8" + + +def test_get_click_type_with_file_binary_read_annotation(): + param_info = ParameterInfo(mode="rb") + result = get_click_type(annotation=FileBinaryRead, parameter_info=param_info) + assert isinstance(result, click.File) + assert result.mode == "rb" + + +def test_get_click_type_with_type_alias_type(): + # define TypeAliasType + Name = TypeAliasType(name="Name", value=str) + Surname = TypeAliasType(name="Surname", value=Name) + + param_info = ParameterInfo() + result = get_click_type(annotation=Name, parameter_info=param_info) + assert result is click.STRING + + # recursive types + param_info = ParameterInfo() + result = get_click_type(annotation=Surname, parameter_info=param_info) + assert result is click.STRING + + +def test_get_click_type_with_unsupported_type(): + class UnsupportedType: + pass + + param_info = ParameterInfo() + with pytest.raises( + RuntimeError, match="Type not yet supported: " + ): + get_click_type(annotation=UnsupportedType, parameter_info=param_info) diff --git a/typer/_typing.py b/typer/_typing.py index 093388cd8d..d118a91e45 100644 --- a/typer/_typing.py +++ b/typer/_typing.py @@ -36,6 +36,11 @@ def is_union(tp: Optional[Type[Any]]) -> bool: return tp is Union or tp is types.UnionType # noqa: E721 +if sys.version_info < (3, 12): + from typing_extensions import TypeAliasType, TypeVar +else: + from typing import TypeAliasType, TypeVar + __all__ = ( "NoneType", "is_none_type", @@ -45,6 +50,8 @@ def is_union(tp: Optional[Type[Any]]) -> bool: "is_union", "Annotated", "Literal", + "TypeAliasType", + "TypeVar", "get_args", "get_origin", "get_type_hints", diff --git a/typer/main.py b/typer/main.py index 71a25e6c4b..fb507f90d1 100644 --- a/typer/main.py +++ b/typer/main.py @@ -11,13 +11,30 @@ from pathlib import Path from traceback import FrameSummary, StackSummary from types import TracebackType -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from uuid import UUID import click from typer._types import TyperChoice -from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values +from ._typing import ( + TypeAliasType, + get_args, + get_origin, + is_literal_type, + is_union, + literal_values, +) from .completion import get_completion_inspect_parameters from .core import ( DEFAULT_MARKUP_MODE, @@ -48,7 +65,7 @@ TyperInfo, TyperPath, ) -from .utils import get_params_from_function +from .utils import get_original_type, get_params_from_function _original_except_hook = sys.excepthook _typer_developer_exception_attr_name = "__typer_developer_exception__" @@ -697,6 +714,9 @@ def wrapper(**kwargs: Any) -> Any: def get_click_type( *, annotation: Any, parameter_info: ParameterInfo ) -> click.ParamType: + if isinstance(annotation, TypeAliasType): + annotation = get_original_type(annotation) + if parameter_info.click_type is not None: return parameter_info.click_type diff --git a/typer/utils.py b/typer/utils.py index 81dc4dd61d..e110008011 100644 --- a/typer/utils.py +++ b/typer/utils.py @@ -3,9 +3,19 @@ from copy import copy from typing import Any, Callable, Dict, List, Tuple, Type, cast -from ._typing import Annotated, get_args, get_origin, get_type_hints +from ._typing import ( + Annotated, + TypeAliasType, + TypeVar, + get_args, + get_origin, + get_type_hints, +) from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta +T = TypeVar("T") +TypeAliasTypeVar = TypeAliasType("TypeAliasTypeVar", value=T, type_params=(T,)) + def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str: # Render a `ParameterInfo` subclass for use in error messages. @@ -188,3 +198,24 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]: name=param.name, default=default, annotation=annotation ) return params + + +def get_original_type(alias: TypeAliasTypeVar[T]) -> T: + """Return the original type of an alias. + + Examples + -------- + >>> Name = TypeAliasType(name="Name", value=str) + >>> Surname = TypeAliasType(name="Surname", value=Name) + >>> get_original_type(Name) + str + >>> get_original_type(Surname) + str + >>> get_original_type(int) + int + """ + otype = alias + while isinstance(otype, TypeAliasType): + otype = otype.__value__ + + return otype