Skip to content

Commit

Permalink
fix: fix answer validation for type: str questions
Browse files Browse the repository at this point in the history
  • Loading branch information
sisp authored and yajo committed Oct 1, 2023
1 parent e38caa2 commit e97d673
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
20 changes: 19 additions & 1 deletion copier/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import stat
import sys
from contextlib import suppress
from decimal import Decimal
from enum import Enum
from importlib.metadata import version
from pathlib import Path
from types import TracebackType
Expand Down Expand Up @@ -90,7 +92,23 @@ def printf_exception(
print(HLINE, file=sys.stderr)


def cast_str_to_bool(value: Any) -> bool:
def cast_to_str(value: Any) -> str:
"""Parse anything to str.
Params:
value:
Anything to be casted to a str.
"""
if isinstance(value, str):
return value.value if isinstance(value, Enum) else value
if isinstance(value, (float, int, Decimal)):
return str(value)
if isinstance(value, (bytes, bytearray)):
return value.decode()
raise ValueError(f"Could not convert {value} to string")


def cast_to_bool(value: Any) -> bool:
"""Parse anything to bool.
Params:
Expand Down
10 changes: 5 additions & 5 deletions copier/user_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from questionary.prompts.common import Choice

from .errors import InvalidTypeError, UserMessageError
from .tools import cast_str_to_bool, force_str_end
from .tools import cast_to_bool, cast_to_str, force_str_end
from .types import MISSING, AnyByStrDict, MissingType, OptStr, OptStrOrPath, StrOrPath


Expand Down Expand Up @@ -377,7 +377,7 @@ def get_type_name(self) -> str:

def get_multiline(self) -> bool:
"""Get the value for multiline."""
return cast_str_to_bool(self.render_value(self.multiline))
return cast_to_bool(self.render_value(self.multiline))

def validate_answer(self, answer) -> bool:
"""Validate user answer."""
Expand All @@ -396,7 +396,7 @@ def validate_answer(self, answer) -> bool:

def get_when(self) -> bool:
"""Get skip condition for question."""
return cast_str_to_bool(self.render_value(self.when))
return cast_to_bool(self.render_value(self.when))

def render_value(
self, value: Any, extra_answers: Optional[AnyByStrDict] = None
Expand Down Expand Up @@ -460,10 +460,10 @@ def load_answersfile_data(


CAST_STR_TO_NATIVE: Mapping[str, Callable] = {
"bool": cast_str_to_bool,
"bool": cast_to_bool,
"float": float,
"int": int,
"json": json.loads,
"str": str,
"str": cast_to_str,
"yaml": parse_yaml_string,
}
15 changes: 15 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import stat
import sys
from contextlib import nullcontext as does_not_raise
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import Any, ContextManager, List

Expand Down Expand Up @@ -455,6 +457,19 @@ def test_value_with_forward_slash(tmp_path_factory: pytest.TempPathFactory) -> N
"1",
does_not_raise(),
),
({"type": "str"}, "string", does_not_raise()),
({"type": "str"}, b"bytes", does_not_raise()),
({"type": "str"}, bytearray("abc", "utf-8"), does_not_raise()),
({"type": "str"}, 1, does_not_raise()),
({"type": "str"}, 1.1, does_not_raise()),
({"type": "str"}, True, does_not_raise()),
({"type": "str"}, False, does_not_raise()),
({"type": "str"}, Decimal(1.1), does_not_raise()),
({"type": "str"}, Enum("A", ["a", "b"], type=str).a, does_not_raise()), # type: ignore
({"type": "str"}, Enum("A", ["a", "b"]).a, pytest.raises(ValueError)), # type: ignore
({"type": "str"}, object(), pytest.raises(ValueError)),
({"type": "str"}, {}, pytest.raises(ValueError)),
({"type": "str"}, [], pytest.raises(ValueError)),
(
{
"type": "str",
Expand Down

0 comments on commit e97d673

Please sign in to comment.