Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ classifiers = [
]
dependencies = [
"click >= 8.0.0",
"typing-extensions >= 3.7.4.3",
"typing-extensions >= 4.6.0",
]
readme = "README.md"
[project.urls]
Expand Down
139 changes: 139 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from uuid import UUID

import click
import pytest
from typer.main import get_click_type
from typer.models import FileBinaryRead, FileTextWrite, ParameterInfo
from typing_extensions import TypeAliasType


def test_get_click_type_with_custom_click_type():
custom_click_type = click.INT
param_info = ParameterInfo(click_type=custom_click_type)
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is custom_click_type


def test_get_click_type_with_custom_parser():
def mock_parser(x):
return 42

param_info = ParameterInfo(parser=mock_parser)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.types.FuncParamType)
assert result.convert("42", None, None) == 42


def test_get_click_type_with_str_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=str, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_int_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=int, parameter_info=param_info)
assert result is click.INT


def test_get_click_type_with_int_annotation_with_min_max():
param_info = ParameterInfo(min=10, max=100)
result = get_click_type(annotation=int, parameter_info=param_info)
assert isinstance(result, click.IntRange)
assert result.min == 10
assert result.max == 100


def test_get_click_type_with_float_annotation_no_min_max():
param_info = ParameterInfo()
result = get_click_type(annotation=float, parameter_info=param_info)
assert result is click.FLOAT


def test_get_click_type_with_float_annotation_with_min_max():
param_info = ParameterInfo(min=0.1, max=10.5)
result = get_click_type(annotation=float, parameter_info=param_info)
assert isinstance(result, click.FloatRange)
assert result.min == 0.1
assert result.max == 10.5


def test_get_click_type_with_bool_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=bool, parameter_info=param_info)
assert result is click.BOOL


def test_get_click_type_with_uuid_annotation():
param_info = ParameterInfo()
result = get_click_type(annotation=UUID, parameter_info=param_info)
assert result is click.UUID


def test_get_click_type_with_datetime_annotation():
param_info = ParameterInfo(formats=["%Y-%m-%d"])
result = get_click_type(annotation=datetime, parameter_info=param_info)
assert isinstance(result, click.DateTime)
assert result.formats == ["%Y-%m-%d"]


def test_get_click_type_with_path_annotation():
param_info = ParameterInfo(resolve_path=True)
result = get_click_type(annotation=Path, parameter_info=param_info)
assert isinstance(result, click.Path)
assert result.resolve_path is True


def test_get_click_type_with_enum_annotation():
class Color(Enum):
RED = "red"
BLUE = "blue"

param_info = ParameterInfo()
result = get_click_type(annotation=Color, parameter_info=param_info)
assert isinstance(result, click.Choice)
assert result.choices == ["red", "blue"]


def test_get_click_type_with_file_text_write_annotation():
param_info = ParameterInfo(mode="w", encoding="utf-8")
result = get_click_type(annotation=FileTextWrite, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "w"
assert result.encoding == "utf-8"


def test_get_click_type_with_file_binary_read_annotation():
param_info = ParameterInfo(mode="rb")
result = get_click_type(annotation=FileBinaryRead, parameter_info=param_info)
assert isinstance(result, click.File)
assert result.mode == "rb"


def test_get_click_type_with_type_alias_type():
# define TypeAliasType
Name = TypeAliasType(name="Name", value=str)
Surname = TypeAliasType(name="Surname", value=Name)

param_info = ParameterInfo()
result = get_click_type(annotation=Name, parameter_info=param_info)
assert result is click.STRING

# recursive types
param_info = ParameterInfo()
result = get_click_type(annotation=Surname, parameter_info=param_info)
assert result is click.STRING


def test_get_click_type_with_unsupported_type():
class UnsupportedType:
pass

param_info = ParameterInfo()
with pytest.raises(
RuntimeError, match="Type not yet supported: <class '.*UnsupportedType.*'>"
):
get_click_type(annotation=UnsupportedType, parameter_info=param_info)
19 changes: 16 additions & 3 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@
from pathlib import Path
from traceback import FrameSummary, StackSummary
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from uuid import UUID

import click
from typing_extensions import get_args, get_origin
from typing_extensions import TypeAliasType, get_args, get_origin

from ._typing import is_union
from .completion import get_completion_inspect_parameters
Expand Down Expand Up @@ -43,7 +53,7 @@
Required,
TyperInfo,
)
from .utils import get_params_from_function
from .utils import get_original_type, get_params_from_function

try:
import rich
Expand Down Expand Up @@ -710,6 +720,9 @@ def wrapper(**kwargs: Any) -> Any:
def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
) -> click.ParamType:
if isinstance(annotation, TypeAliasType):
annotation = get_original_type(annotation)

if parameter_info.click_type is not None:
return parameter_info.click_type

Expand Down
32 changes: 31 additions & 1 deletion typer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,19 @@
from copy import copy
from typing import Any, Callable, Dict, List, Tuple, Type, cast

from typing_extensions import Annotated, get_args, get_origin, get_type_hints
from typing_extensions import (
Annotated,
TypeAliasType,
TypeVar,
get_args,
get_origin,
get_type_hints,
)

from .models import ArgumentInfo, OptionInfo, ParameterInfo, ParamMeta

T = TypeVar("T")


def _param_type_to_user_string(param_type: Type[ParameterInfo]) -> str:
# Render a `ParameterInfo` subclass for use in error messages.
Expand Down Expand Up @@ -189,3 +198,24 @@ def get_params_from_function(func: Callable[..., Any]) -> Dict[str, ParamMeta]:
name=param.name, default=default, annotation=annotation
)
return params


def get_original_type(alias: TypeAliasType) -> T:
"""Return the original type of an alias.

Examples
--------
>>> Name = TypeAliasType(name="Name", value=str)
>>> Surname = TypeAliasType(name="Surname", value=Name)
>>> get_original_type(Name)
str
>>> get_original_type(Surname)
str
>>> get_original_type(int)
int
"""
otype = alias
while isinstance(otype, TypeAliasType):
otype = otype.__value__

return otype