Skip to content
142 changes: 101 additions & 41 deletions sdk/core/azure-core/azure/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,17 @@ def _is_readonly(p):
class AzureJSONEncoder(JSONEncoder):
"""A JSON encoder that's capable of serializing datetime objects and bytes."""

def __init__(self, *args, exclude_readonly: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.exclude_readonly = exclude_readonly

def default(self, o): # pylint: disable=too-many-return-statements
if _is_model(o):
readonly_props = [
p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)
]
return {k: v for k, v in o.items() if k not in readonly_props}
if self.exclude_readonly:
readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
return {k: v for k, v in o.items() if k not in readonly_props}
else:
return dict(o.items())
if isinstance(o, (bytes, bytearray)):
return base64.b64encode(o).decode()
if isinstance(o, _Null):
Expand Down Expand Up @@ -309,15 +314,29 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] =
return _DESERIALIZE_MAPPING.get(annotation)


def _get_type_alias_type(module_name: str, alias_name: str):
types = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, typing._GenericAlias) # type: ignore
}
if alias_name not in types:
return alias_name
return types[alias_name]


def _get_model(module_name: str, model_name: str):
models = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, type)
}
module_end = module_name.rsplit(".", 1)[0]
module = sys.modules[module_end]
models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)})
models.update({
k: v
for k, v in sys.modules[module_end].__dict__.items()
if isinstance(v, type)
})
if isinstance(model_name, str):
model_name = model_name.split(".")[-1]
if model_name not in models:
Expand Down Expand Up @@ -547,25 +566,57 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member

@classmethod
def _get_discriminator(cls) -> typing.Optional[str]:
def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
for v in cls.__dict__.values():
if (
isinstance(v, _RestField) and v._is_discriminator
): # pylint: disable=protected-access
if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
return v._rest_name # pylint: disable=protected-access
return None

@classmethod
def _deserialize(cls, data):
def _deserialize(cls, data, exist_discriminators):
if not hasattr(cls, "__mapping__"): # pylint: disable=no-member
return cls(data)
discriminator = cls._get_discriminator()
discriminator = cls._get_discriminator(exist_discriminators)
exist_discriminators.append(discriminator)
mapped_cls = cls.__mapping__.get(
data.get(discriminator), cls
) # pylint: disable=no-member
if mapped_cls == cls:
return cls(data)
return mapped_cls._deserialize(data) # pylint: disable=protected-access
return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access

def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
"""Return a dict that can be JSONify using json.dump.

:keyword bool exclude_readonly: Whether to remove the readonly properties.
:returns: A dict JSON compatible object
:rtype: dict
"""

result = {}
if exclude_readonly:
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
for k, v in self.items():
if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false
continue
result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
return result

@staticmethod
def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
if v is None or isinstance(v, _Null):
return None
if isinstance(v, (list, tuple, set)):
return [
Model._as_dict_value(x, exclude_readonly=exclude_readonly)
for x in v
]
if isinstance(v, dict):
return {
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
for dk, dv in v.items()
}
return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v


def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements
Expand All @@ -576,8 +627,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
if not annotation or annotation in [int, float]:
return None

# is it a type alias?
if isinstance(annotation, str):
if module is not None:
annotation = _get_type_alias_type(module, annotation)

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)
Comment on lines +635 to +642
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ForwardRef should be resolved at first.


try:
if module and _is_model(_get_model(module, annotation)):
if module and _is_model(annotation):
if rf:
rf._is_model = True

Expand All @@ -588,7 +653,7 @@ def _deserialize_model(
return obj
return _deserialize(model_deserializer, obj)

return functools.partial(_deserialize_model, _get_model(module, annotation))
return functools.partial(_deserialize_model, annotation)
except Exception:
pass

Expand All @@ -606,22 +671,8 @@ def _deserialize_model(
except AttributeError:
pass

if getattr(annotation, "__origin__", None) is typing.Union:

def _deserialize_with_union(union_annotation, obj):
for t in union_annotation.__args__:
try:
return _deserialize(t, obj, module, rf)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, annotation)

# is it optional?
try:
# right now, assuming we don't have unions, since we're getting rid of the only
# union we used to have in msrest models, which was union of str and enum
if any(a for a in annotation.__args__ if a == type(None)):
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf
Expand All @@ -638,14 +689,18 @@ def _deserialize_with_optional(
except AttributeError:
pass

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)
if getattr(annotation, "__origin__", None) is typing.Union:
deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__]

def _deserialize_with_union(deserializers, obj):
for deserializer in deserializers:
try:
return _deserialize(deserializer, obj)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, deserializers)

try:
if annotation._name == "Dict":
Expand Down Expand Up @@ -751,7 +806,7 @@ def _deserialize_with_callable(
# for unknown value, return raw value
return value
if isinstance(deserializer, type) and issubclass(deserializer, Model):
return deserializer._deserialize(value)
return deserializer._deserialize(value, [])
return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(
value
)
Expand Down Expand Up @@ -805,7 +860,9 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin
item = obj.get(self._rest_name)
if item is None:
return item
return _deserialize(self._type, _serialize(item, self._format), rf=self)
if self._is_model:
return item
return _deserialize(self._type, item, rf=self)

def __set__(self, obj: Model, value) -> None:
if value is None:
Expand All @@ -815,8 +872,11 @@ def __set__(self, obj: Model, value) -> None:
except KeyError:
pass
return
if self._is_model and not _is_model(value):
obj.__setitem__(self._rest_name, _deserialize(self._type, value))
if self._is_model:
if not _is_model(value):
value = _deserialize(self._type, value)
obj.__setitem__(self._rest_name, value)
return
obj.__setitem__(self._rest_name, _serialize(value, self._format))

def _get_deserialize_callable_from_annotation(
Expand Down
Loading