Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion simple_parsing/helpers/serialization/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
is_tuple,
is_typevar,
is_union,
str2bool,
)

logger = getLogger(__name__)
Expand All @@ -36,10 +37,19 @@
_decoding_fns: Dict[Type[T], Callable[[Any], T]] = {
# the 'primitive' types are decoded using the type fn as a constructor.
t: t
for t in [str, float, int, bool, bytes]
for t in [str, float, int, bytes]
}


def decode_bool(v: Any) -> bool:
if isinstance(v, str):
return str2bool(v)
return bool(v)


_decoding_fns[bool] = decode_bool


def decode_field(field: Field, raw_value: Any) -> Any:
"""Converts a "raw" value (e.g. from json file) to the type of the `field`.

Expand Down
31 changes: 31 additions & 0 deletions test/test_issue_107.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
""" test for https://github.com/lebrice/SimpleParsing/issues/107 """
from dataclasses import dataclass
from typing import Any

import pytest

from simple_parsing.helpers.serialization.serializable import Serializable


@dataclass
class Foo(Serializable):
a: bool = False


@pytest.mark.parametrize(
"passed, expected",
[
("True", True),
("False", False),
(True, True),
(False, False),
("true", True),
("false", False),
("1", True),
("0", False),
(1, True),
(0, False),
],
)
def test_parsing_of_bool_works_as_expected(passed: Any, expected: bool):
assert Foo.from_dict({"a": passed}) == Foo(a=expected)