Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
Make PyText compatible with Any type
Browse files Browse the repository at this point in the history
Summary:
As discussed in https://fb.workplace.com/groups/300451907202972/permalink/566773450570815/, it seems like there is no support for the following:
 - `Any` or `Dict[str, Any]`: PyText serailize.py complains
 - `Dict[str, Union[str, float]]`: bad support of Union in Dict at flow level

 Anyway, this is blocking and I don't see stark downside of adding interim support for `Any` in `serialize.py`. This typing issue is blocking the large stack.

Reviewed By: kmalik22

Differential Revision: D19699214

fbshipit-source-id: 370a7d138aae3c9a96d0e302e1657585242098e6
  • Loading branch information
jessemin authored and facebook-github-bot committed Feb 5, 2020
1 parent d93b589 commit 4aa8ffd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
19 changes: 17 additions & 2 deletions pytext/config/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from enum import Enum
from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

from pytext.common.utils import eprint

Expand Down Expand Up @@ -94,6 +94,19 @@ def _union_from_json(subclasses, json_obj):
) from e


def _any_from_json(cls, json_obj):
if _is_dict(json_obj):
if len(json_obj) == 1:
type_name, value = list(json_obj.keys())[0], list(json_obj.values())[0]
assert (
type_name == type(value).__name__
), f"type of value mismatches: {type_name} vs. {type(value).__name__} from {value}"
return value
else:
raise TypeError("PyText Config currently don't support this")
return json_obj


def _is_optional(cls):
return _get_class_type(cls) == Union and type(None) in cls.__args__

Expand All @@ -118,6 +131,8 @@ def _value_from_json(cls, value):
# nested config
elif hasattr(cls, "_fields"):
return config_from_json(cls, value)
elif cls_type == Any:
return _any_from_json(cls, value)
elif cls_type == Union:
return _union_from_json(cls.__args__, value)
elif issubclass(cls_type, Enum):
Expand Down Expand Up @@ -221,7 +236,7 @@ def _value_to_json(cls, value):
elif _is_optional(cls) and len(cls.__args__) == 2:
sub_cls = cls.__args__[0] if type(None) != cls.__args__[0] else cls.__args__[1]
return _value_to_json(sub_cls, value)
elif cls_type == Union or getattr(cls, "__EXPANSIBLE__", False):
elif cls_type == Any or cls_type == Union or getattr(cls, "__EXPANSIBLE__", False):
real_cls = type(value)
if hasattr(real_cls, "_fields"):
value = config_to_json(real_cls, value)
Expand Down
26 changes: 25 additions & 1 deletion pytext/config/test/serialize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest
from typing import Union
from typing import Any, Dict, Union

from pytext.config import serialize


SAMPLE_INT_JSON = {"int": 6}
SAMPLE_UNION_CLS = Union[str, int]

SAMPLE_DICT_WITH_ANY_CLS = Dict[str, Any]
DICT_WITH_ANY: SAMPLE_DICT_WITH_ANY_CLS = {"lr": 0.1, "type": "FedAvg"}
SAMPLE_DICT_WITH_ANY_JSON = {"lr": {"float": 0.1}, "type": {"str": "FedAvg"}}

SAMPLE_ANY_CLS = Any
MULTI_TYPE_LIST: Any = [1, "test", 0.01]
SAMPLE_ANY: Any = {"list": MULTI_TYPE_LIST}


class SerializeTest(unittest.TestCase):
def test_value_from_json(self):
Expand All @@ -21,3 +29,19 @@ def test_value_to_json(self):
print()
json = serialize._value_to_json(Union[str, int], 6)
self.assertEqual(SAMPLE_INT_JSON, json)

def test_value_to_json_for_class_type_any(self):
json = serialize._value_to_json(Dict[str, Any], DICT_WITH_ANY)
self.assertEqual(json, SAMPLE_DICT_WITH_ANY_JSON)

json = serialize._value_to_json(Any, [1, "test", 0.01])
self.assertEqual(json, SAMPLE_ANY)

def test_value_from_json_for_class_type_any(self):
value = serialize._value_from_json(
SAMPLE_DICT_WITH_ANY_CLS, SAMPLE_DICT_WITH_ANY_JSON
)
self.assertEqual(DICT_WITH_ANY, value)

value = serialize._value_from_json(SAMPLE_ANY_CLS, SAMPLE_ANY)
self.assertEqual([1, "test", 0.01], value)

0 comments on commit 4aa8ffd

Please sign in to comment.