Skip to content

Commit

Permalink
Add option to return none on unknown union variants (#152)
Browse files Browse the repository at this point in the history
Co-authored-by: Mark Greatorex <[email protected]>
  • Loading branch information
markgrex and Mark Greatorex authored May 22, 2023
1 parent edcc41f commit ad533fa
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 37 deletions.
25 changes: 21 additions & 4 deletions conjure_python_client/_http/requests_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,26 @@ class Service(object):
_connect_timeout = None # type: float
_read_timeout = None # type: float
_verify = None # type: str
_return_none_for_unknown_union_types = False # type: bool

def __init__(
self, requests_session, uris, _connect_timeout, _read_timeout, _verify
self,
requests_session,
uris,
_connect_timeout,
_read_timeout,
_verify,
_return_none_for_unknown_union_types=False
):
# type: (requests.Session, List[str], float, float, str) -> None
# type: (requests.Session, List[str], float, float, str, bool) -> None
self._requests_session = requests_session
self._uris = uris
self._connect_timeout = _connect_timeout
self._read_timeout = _read_timeout
self._verify = _verify
self._return_none_for_unknown_union_types = (
_return_none_for_unknown_union_types
)

@property
def _uri(self):
Expand Down Expand Up @@ -156,8 +166,14 @@ def get_backoff_time(self):
class RequestsClient(object):

@classmethod
def create(cls, service_class, user_agent, service_config):
# type: (Type[T], str, ServiceConfiguration) -> T
def create(
cls,
service_class,
user_agent,
service_config,
return_none_for_unknown_union_types=False
):
# type: (Type[T], str, ServiceConfiguration, bool) -> T
# setup retry to match java remoting
# https://github.com/palantir/http-remoting/tree/3.12.0#quality-of-service-retry-failover-throttling
retry = RetryWithJitter(
Expand All @@ -182,6 +198,7 @@ def create(cls, service_class, user_agent, service_config):
service_config.connect_timeout,
service_config.read_timeout,
verify,
return_none_for_unknown_union_types,
)


Expand Down
145 changes: 112 additions & 33 deletions conjure_python_client/_serde/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,19 @@ class ConjureDecoder(object):
"""Decodes json into a conjure object"""

@classmethod
def decode_conjure_bean_type(cls, obj, conjure_type):
def decode_conjure_bean_type(
cls, obj, conjure_type, return_none_for_unknown_union_types=False
):
"""Decodes json into a conjure bean type (a plain bean, not enum
or union).
Args:
obj: the json object to decode
conjure_type: a class object which is the bean type
we're decoding into
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
Returns:
A instance of a bean of type conjure_type.
"""
Expand All @@ -53,8 +58,9 @@ def decode_conjure_bean_type(cls, obj, conjure_type):
else:
value = obj[field_identifier]
field_type = field_definition.field_type
deserialized[python_arg_name] = \
cls.do_decode(value, field_type)
deserialized[python_arg_name] = cls.do_decode(
value, field_type, return_none_for_unknown_union_types
)
return conjure_type(**deserialized)

@classmethod
Expand All @@ -74,13 +80,18 @@ def check_null_field(
)

@classmethod
def decode_conjure_union_type(cls, obj, conjure_type):
def decode_conjure_union_type(
cls, obj, conjure_type, return_none_for_unknown_union_types=False
):
"""Decodes json into a conjure union type.
Args:
obj: the json object to decode
conjure_type: a class object which is the union type
we're decoding into
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
Returns:
An instance of type conjure_type.
"""
Expand All @@ -91,11 +102,14 @@ def decode_conjure_union_type(cls, obj, conjure_type):
conjure_field_definition = conjure_field
break
else:
raise ValueError(
"unknown union type {0} for {1}".format(
type_of_union, conjure_type
if return_none_for_unknown_union_types:
return None
else:
raise ValueError(
"unknown union type {0} for {1}".format(
type_of_union, conjure_type
)
)
)

deserialized = {} # type: Dict[str, Any]
if type_of_union not in obj or obj[type_of_union] is None:
Expand All @@ -104,7 +118,9 @@ def decode_conjure_union_type(cls, obj, conjure_type):
else:
value = obj[type_of_union]
field_type = conjure_field_definition.field_type
deserialized[attribute] = cls.do_decode(value, field_type)
deserialized[attribute] = cls.do_decode(
value, field_type, return_none_for_unknown_union_types
)

# for backwards compatibility with conjure-python,
# only pass in type_of_union if it is expected
Expand Down Expand Up @@ -141,6 +157,7 @@ def decode_dict(
obj, # type: Dict[Any, Any]
key_type, # ConjureTypeType
item_type, # ConjureTypeType
return_none_for_unknown_union_types=False # bool,
): # type: (...) -> Dict[Any, Any]
"""Decodes json into a dictionary, handling conversion of the
keys/values (the keys/values may themselves require conversion).
Expand All @@ -151,6 +168,9 @@ def decode_dict(
of the keys in this dict
item_type: a class object which is the conjure type
of the values in this dict
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
Returns:
A python dictionary, where the keys are instances of type key_type
and the values are of type value_type.
Expand All @@ -160,50 +180,83 @@ def decode_dict(
if key_type == str or isinstance(key_type, BinaryType) \
or (inspect.isclass(key_type)
and issubclass(key_type, ConjureEnumType)):
return dict((
(cls.do_decode(x[0], key_type), cls.do_decode(x[1], item_type))
return dict(((
cls.do_decode(
x[0], key_type, return_none_for_unknown_union_types
),
cls.do_decode(
x[1], item_type, return_none_for_unknown_union_types
)
)
for x in obj.items()))

return dict((
(cls.do_decode(json.loads(x[0]), key_type),
cls.do_decode(x[1], item_type))
(
cls.do_decode(
json.loads(x[0]),
key_type,
return_none_for_unknown_union_types
),
cls.do_decode(
x[1], item_type, return_none_for_unknown_union_types
)
)
for x in obj.items()))

@classmethod
def decode_list(cls, obj, element_type):
# type: (List[Any], ConjureTypeType) -> List[Any]
def decode_list(
cls, obj, element_type, return_none_for_unknown_union_types=False
):
# type: (List[Any], ConjureTypeType, bool) -> List[Any]
"""Decodes json into a list, handling conversion of the elements.
Args:
obj: the json object to decode
element_type: a class object which is the conjure type of
the elements in this list.
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
Returns:
A python list where the elements are instances of type
element_type.
"""
if not isinstance(obj, list):
raise Exception("expected a python list")

return list(map(lambda x: cls.do_decode(x, element_type), obj))
return list(
map(
lambda x: cls.do_decode(
x, element_type, return_none_for_unknown_union_types
),
obj,
)
)

@classmethod
def decode_optional(cls, obj, object_type):
# type: (Optional[Any], ConjureTypeType) -> Optional[Any]
def decode_optional(
cls, obj, object_type, return_none_for_unknown_union_types=False
):
# type: (Optional[Any], ConjureTypeType, bool) -> Optional[Any]
"""Decodes json into an element, returning None if the provided object
is None.
Args:
obj: the json object to decode
object_type: a class object which is the conjure type of
the object if present.
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
Returns:
The decoded obj or None if no obj is provided.
"""
if obj is None:
return None

return cls.do_decode(obj, object_type)
return cls.do_decode(
obj, object_type, return_none_for_unknown_union_types
)

@classmethod
def decode_primitive(cls, obj, object_type):
Expand All @@ -225,45 +278,71 @@ def raise_mismatch():
return obj

@classmethod
def do_decode(cls, obj, obj_type):
# type: (Any, ConjureTypeType) -> Any
def do_decode(
cls, obj, obj_type, return_none_for_unknown_union_types=False
):
# type: (Any, ConjureTypeType, bool) -> Any
"""Decodes json into the specified type
Args:
obj: the json object to decode
element_type: a class object which is the type we're decoding into.
return_none_for_unknown_union_types: if set to True, returns None
instead of raising an exception when an unknown union type is
encountered
"""
if inspect.isclass(obj_type) and issubclass( # type: ignore
obj_type, ConjureBeanType
):
return cls.decode_conjure_bean_type(obj, obj_type) # type: ignore
return cls.decode_conjure_bean_type(
obj, obj_type, return_none_for_unknown_union_types
) # type: ignore

elif inspect.isclass(obj_type) and issubclass( # type: ignore
obj_type, ConjureUnionType
):
return cls.decode_conjure_union_type(obj, obj_type)
return cls.decode_conjure_union_type(
obj, obj_type, return_none_for_unknown_union_types
)

elif inspect.isclass(obj_type) and issubclass( # type: ignore
obj_type, ConjureEnumType
):
return cls.decode_conjure_enum_type(obj, obj_type)

elif isinstance(obj_type, DictType):
return cls.decode_dict(obj, obj_type.key_type, obj_type.value_type)
return cls.decode_dict(
obj,
obj_type.key_type,
obj_type.value_type,
return_none_for_unknown_union_types,
)

elif isinstance(obj_type, ListType):
return cls.decode_list(obj, obj_type.item_type)
return cls.decode_list(
obj, obj_type.item_type, return_none_for_unknown_union_types
)

elif isinstance(obj_type, OptionalType):
return cls.decode_optional(obj, obj_type.item_type)
return cls.decode_optional(
obj, obj_type.item_type, return_none_for_unknown_union_types
)

return cls.decode_primitive(obj, obj_type)

def decode(self, obj, obj_type):
# type: (Any, ConjureTypeType) -> Any
return self.do_decode(obj, obj_type)

def read_from_string(self, string_value, obj_type):
# type: (str, ConjureTypeType) -> Any
def decode(self, obj, obj_type, return_none_for_unknown_union_types=False):
# type: (Any, ConjureTypeType, bool) -> Any
return self.do_decode(
obj, obj_type, return_none_for_unknown_union_types)

def read_from_string(
self,
string_value,
obj_type,
return_none_for_unknown_union_types=False
):
# type: (str, ConjureTypeType, bool) -> Any
deserialized = json.loads(string_value)
return self.decode(deserialized, obj_type)
return self.decode(
deserialized, obj_type, return_none_for_unknown_union_types
)
34 changes: 34 additions & 0 deletions test/serde/test_decode_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# (c) Copyright 2023 Palantir Technologies Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from conjure_python_client import ConjureDecoder
from test.generated.conjure_verification_types import Union


def test_union_with_unknown_type_fails():
with pytest.raises(ValueError) as e:
ConjureDecoder().read_from_string(
'{"type": "unknown", "unknown": "unknown_value"}', Union, False
)
assert e.match(
"unknown union type unknown for <class 'generated.conjure_verification_types.Union'>"
)


def test_union_with_unknown_type_and_return_none_for_unknown_types_succeeds():
decoded = ConjureDecoder().read_from_string(
'{"type": "unknown", "unknown": "unknown_value"}', Union, True
)
assert decoded is None

0 comments on commit ad533fa

Please sign in to comment.