Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ classifiers = [
]
dependencies = [
"click >= 8.0.0",
"typing-extensions >= 3.7.4.3",
"typing-extensions >= 4.6.0",
]
readme = "README.md"
[project.urls]
Expand Down
139 changes: 139 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from uuid import UUID

import click
import pytest
from typer._typing import TypeAliasType
from typer.main import get_click_type
from typer.models import FileBinaryRead, FileTextWrite, ParameterInfo


def test_get_click_type_with_custom_click_type():
custom_click_type = click.INT
param_info = ParameterInfo(click_type=custom_click_type)
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is custom_click_type


def test_get_click_type_with_custom_parser():
def mock_parser(x):
return 42

param_info = ParameterInfo(parser=mock_parser)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.types.FuncParamType)
assert result.convert("42", None, None) == 42


def test_get_click_type_with_str_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=str, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_int_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is click.INT


def test_get_click_type_with_int_annotation_with_min_max():
param_info = ParameterInfo(min=10, max=100)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.IntRange)
assert result.min == 10
assert result.max == 100


def test_get_click_type_with_float_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=float, parameter_info=param_info)
assert result is click.FLOAT


def test_get_click_type_with_float_annotation_with_min_max():
param_info = ParameterInfo(min=0.1, max=10.5)
result = get_click_type(annotation=float, parameter_info=param_info)
assert isinstance(result, click.FloatRange)
assert result.min == 0.1
assert result.max == 10.5


def test_get_click_type_with_bool_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=bool, parameter_info=param_info)
assert result is click.BOOL


def test_get_click_type_with_uuid_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=UUID, parameter_info=param_info)
assert result is click.UUID


def test_get_click_type_with_datetime_annotation():
param_info = ParameterInfo(formats=["%Y-%m-%d"])
result = get_click_type(annotation=datetime, parameter_info=param_info)
assert isinstance(result, click.DateTime)
assert result.formats == ["%Y-%m-%d"]


def test_get_click_type_with_path_annotation():
param_info = ParameterInfo(resolve_path=True)
result = get_click_type(annotation=Path, parameter_info=param_info)
assert isinstance(result, click.Path)
assert result.resolve_path is True


def test_get_click_type_with_enum_annotation():
class Color(Enum):
RED = "red"
BLUE = "blue"

param_info = ParameterInfo()
result = get_click_type(annotation=Color, parameter_info=param_info)
assert isinstance(result, click.Choice)
assert result.choices == ("red", "blue")


def test_get_click_type_with_file_text_write_annotation():
param_info = ParameterInfo(mode="w", encoding="utf-8")
result = get_click_type(annotation=FileTextWrite, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "w"
assert result.encoding == "utf-8"


def test_get_click_type_with_file_binary_read_annotation():
param_info = ParameterInfo(mode="rb")
result = get_click_type(annotation=FileBinaryRead, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "rb"


def test_get_click_type_with_type_alias_type():
# define TypeAliasType
Name = TypeAliasType(name="Name", value=str)
Surname = TypeAliasType(name="Surname", value=Name)

param_info = ParameterInfo()
result = get_click_type(annotation=Name, parameter_info=param_info)
assert result is click.STRING

# recursive types
param_info = ParameterInfo()
result = get_click_type(annotation=Surname, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_unsupported_type():
class UnsupportedType:
pass

param_info = ParameterInfo()
with pytest.raises(
RuntimeError, match="Type not yet supported: <class '.*UnsupportedType.*'>"
):
get_click_type(annotation=UnsupportedType, parameter_info=param_info)
7 changes: 7 additions & 0 deletions typer/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def is_union(tp: Optional[Type[Any]]) -> bool:
return tp is Union or tp is types.UnionType # noqa: E721


if sys.version_info < (3, 12):
from typing_extensions import TypeAliasType, TypeVar
else:
from typing import TypeAliasType, TypeVar

__all__ = (
"NoneType",
"is_none_type",
Expand All @@ -45,6 +50,8 @@ def is_union(tp: Optional[Type[Any]]) -> bool:
"is_union",
"Annotated",
"Literal",
"TypeAliasType",
"TypeVar",
"get_args",
"get_origin",
"get_type_hints",
Expand Down
26 changes: 23 additions & 3 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,30 @@
from pathlib import Path
from traceback import FrameSummary, StackSummary
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from uuid import UUID

import click
from typer._types import TyperChoice

from ._typing import get_args, get_origin, is_literal_type, is_union, literal_values
from ._typing import (
TypeAliasType,
get_args,
get_origin,
is_literal_type,
is_union,
literal_values,
)
from .completion import get_completion_inspect_parameters
from .core import (
DEFAULT_MARKUP_MODE,
Expand Down Expand Up @@ -48,7 +65,7 @@
TyperInfo,
TyperPath,
)
from .utils import get_params_from_function
from .utils import get_original_type, get_params_from_function

_original_except_hook = sys.excepthook
_typer_developer_exception_attr_name = "__typer_developer_exception__"
Expand Down Expand Up @@ -697,6 +714,9 @@ def wrapper(**kwargs: Any) -> Any:
def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
) -> click.ParamType:
if isinstance(annotation, TypeAliasType):
annotation = get_original_type(annotation)

if parameter_info.click_type is not None:
return parameter_info.click_type

Expand Down
33 changes: 32 additions & 1 deletion typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,19 @@
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast

from ._typing import Annotated, get_args, get_origin, get_type_hints
from ._typing import (
Annotated,
TypeAliasType,
TypeVar,
get_args,
get_origin,
get_type_hints,
)
from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta

T = TypeVar("T")
TypeAliasTypeVar = TypeAliasType("TypeAliasTypeVar", value=T, type_params=(T,))


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
# Render a `ParameterInfo` subclass for use in error messages.
Expand Down Expand Up @@ -188,3 +198,24 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
name=param.name, default=default, annotation=annotation
)
return params


def get_original_type(alias: TypeAliasTypeVar[T]) -> T:
"""Return the original type of an alias.

Examples
--------
>>> Name = TypeAliasType(name="Name", value=str)
>>> Surname = TypeAliasType(name="Surname", value=Name)
>>> get_original_type(Name)
str
>>> get_original_type(Surname)
str
>>> get_original_type(int)
int
"""
otype = alias
while isinstance(otype, TypeAliasType):
otype = otype.__value__

return otype
Loading