Skip to content
115 changes: 82 additions & 33 deletions sdk/core/azure-core/azure/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,24 @@ def _is_readonly(p):
class AzureJSONEncoder(JSONEncoder):
"""A JSON encoder that's capable of serializing datetime objects and bytes."""

def __init__(self, *args, exclude_readonly: bool = False, exclude_none: bool = False, **kwargs):
super().__init__(*args, **kwargs)
Comment thread
tadelesh marked this conversation as resolved.
self.exclude_readonly = exclude_readonly
self.exclude_none = exclude_none

def default(self, o): # pylint: disable=too-many-return-statements
if _is_model(o):
readonly_props = [
p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)
]
return {k: v for k, v in o.items() if k not in readonly_props}
result = {k: v for k, v in o.items()}
if self.exclude_readonly:
readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
for k in readonly_props:
if k in result:
result.pop(k)
Comment thread
tadelesh marked this conversation as resolved.
Outdated
if self.exclude_none:
for k in list(result.keys()):
Comment thread
tadelesh marked this conversation as resolved.
Outdated
if result[k] is None or isinstance(result[k], _Null):
result.pop(k)
return result
if isinstance(o, (bytes, bytearray)):
return base64.b64encode(o).decode()
if isinstance(o, _Null):
Expand Down Expand Up @@ -313,11 +325,14 @@ def _get_model(module_name: str, model_name: str):
models = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, type)
if isinstance(v, type) or isinstance(v, typing._GenericAlias)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could resolve type alias problem.

}
module_end = module_name.rsplit(".", 1)[0]
module = sys.modules[module_end]
models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)})
models.update({
k: v
for k, v in sys.modules[module_end].__dict__.items()
if isinstance(v, type) or isinstance(v, typing._GenericAlias)
})
if isinstance(model_name, str):
model_name = model_name.split(".")[-1]
if model_name not in models:
Expand Down Expand Up @@ -567,6 +582,37 @@ def _deserialize(cls, data):
return cls(data)
return mapped_cls._deserialize(data) # pylint: disable=protected-access

def as_dict(self, *, exclude_readonly: bool = False, exclude_none: bool = False) -> typing.Dict[str, typing.Any]:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

json.dumps will return a string and need to loads again to get dict. I prefer to use more efficient way.

result = {}
if exclude_readonly:
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
for k, v in self.items():
if exclude_readonly and k in readonly_props:
continue
if exclude_none and (v is None or isinstance(v, _Null)):
continue
result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly, exclude_none=exclude_none)
return result

@staticmethod
def _as_dict_value(v: typing.Any, exclude_readonly: bool = False, exclude_none: bool = False) -> typing.Any:
if v is None or isinstance(v, _Null):
return None
elif isinstance(v, (list, tuple, set)):
return [
Model._as_dict_value(x, exclude_readonly=exclude_readonly, exclude_none=exclude_none)
for x in v
]
elif isinstance(v, dict):
return {
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly, exclude_none=exclude_none)
for dk, dv in v.items()
}
else:
return v.as_dict(exclude_readonly=exclude_readonly, exclude_none=exclude_none) \
if hasattr(v, "as_dict") else v



def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements
annotation: typing.Any,
Expand All @@ -576,8 +622,17 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
if not annotation or annotation in [int, float]:
return None

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)
Comment on lines +635 to +642

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ForwardRef should be resolved at first.


try:
if module and _is_model(_get_model(module, annotation)):
if module and _is_model(annotation):
if rf:
rf._is_model = True

Expand Down Expand Up @@ -606,22 +661,8 @@ def _deserialize_model(
except AttributeError:
pass

if getattr(annotation, "__origin__", None) is typing.Union:

def _deserialize_with_union(union_annotation, obj):
for t in union_annotation.__args__:
try:
return _deserialize(t, obj, module, rf)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, annotation)

# is it optional?
try:
# right now, assuming we don't have unions, since we're getting rid of the only
# union we used to have in msrest models, which was union of str and enum
if any(a for a in annotation.__args__ if a == type(None)):
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf
Expand All @@ -638,14 +679,17 @@ def _deserialize_with_optional(
except AttributeError:
pass

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)
if getattr(annotation, "__origin__", None) is typing.Union:

def _deserialize_with_union(union_annotation, obj):
for t in union_annotation.__args__:
try:
return _deserialize(t, obj, module, rf)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, annotation)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Union should be handled after Optional, bc Optional is also a Union.


try:
if annotation._name == "Dict":
Expand Down Expand Up @@ -805,7 +849,9 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin
item = obj.get(self._rest_name)
if item is None:
return item
return _deserialize(self._type, _serialize(item, self._format), rf=self)
if self._is_model:
return item
return _deserialize(self._type, item, rf=self)

def __set__(self, obj: Model, value) -> None:
if value is None:
Expand All @@ -815,8 +861,11 @@ def __set__(self, obj: Model, value) -> None:
except KeyError:
pass
return
if self._is_model and not _is_model(value):
obj.__setitem__(self._rest_name, _deserialize(self._type, value))
if self._is_model:
if not _is_model(value):
value = _deserialize(self._type, value)
obj.__setitem__(self._rest_name, value)
return
obj.__setitem__(self._rest_name, _serialize(value, self._format))

def _get_deserialize_callable_from_annotation(
Expand Down
Loading