From 68a66e862615cb22e25aa3f02fb98acac663d8d9 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Mon, 6 Jun 2022 13:48:36 -0400 Subject: [PATCH] Fix issue 107 (decoding of 'False' -> True) Signed-off-by: Fabrice Normandin --- .../helpers/serialization/decoding.py | 12 ++++++- test/test_issue_107.py | 31 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 test/test_issue_107.py diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 47f2d512..9934d1c5 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -24,6 +24,7 @@ is_tuple, is_typevar, is_union, + str2bool, ) logger = getLogger(__name__) @@ -36,10 +37,19 @@ _decoding_fns: Dict[Type[T], Callable[[Any], T]] = { # the 'primitive' types are decoded using the type fn as a constructor. t: t - for t in [str, float, int, bool, bytes] + for t in [str, float, int, bytes] } +def decode_bool(v: Any) -> bool: + if isinstance(v, str): + return str2bool(v) + return bool(v) + + +_decoding_fns[bool] = decode_bool + + def decode_field(field: Field, raw_value: Any) -> Any: """Converts a "raw" value (e.g. from json file) to the type of the `field`. diff --git a/test/test_issue_107.py b/test/test_issue_107.py new file mode 100644 index 00000000..608edca0 --- /dev/null +++ b/test/test_issue_107.py @@ -0,0 +1,31 @@ +""" test for https://github.com/lebrice/SimpleParsing/issues/107 """ +from dataclasses import dataclass +from typing import Any + +import pytest + +from simple_parsing.helpers.serialization.serializable import Serializable + + +@dataclass +class Foo(Serializable): + a: bool = False + + +@pytest.mark.parametrize( + "passed, expected", + [ + ("True", True), + ("False", False), + (True, True), + (False, False), + ("true", True), + ("false", False), + ("1", True), + ("0", False), + (1, True), + (0, False), + ], +) +def test_parsing_of_bool_works_as_expected(passed: Any, expected: bool): + assert Foo.from_dict({"a": passed}) == Foo(a=expected)