Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def _serialize_body_parameter(self, builder: OperationType) -> List[str]:
elif self.code_model.options["models_mode"] == "dpg":
create_body_call = (
f"_{body_kwarg_name} = json.dumps({body_param.client_name}, "
"cls=AzureJSONEncoder) # type: ignore"
"cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore"
)
else:
create_body_call = f"_{body_kwarg_name} = {body_param.client_name}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,16 @@ def _is_readonly(p):
class AzureJSONEncoder(JSONEncoder):
"""A JSON encoder that's capable of serializing datetime objects and bytes."""

def __init__(self, *args, exclude_readonly: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.exclude_readonly = exclude_readonly

def default(self, o): # pylint: disable=too-many-return-statements
if _is_model(o):
readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
return {k: v for k, v in o.items() if k not in readonly_props}
if self.exclude_readonly:
readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)]
return {k: v for k, v in o.items() if k not in readonly_props}
return dict(o.items())
if isinstance(o, (bytes, bytearray)):
return base64.b64encode(o).decode()
if isinstance(o, _Null):
Expand Down Expand Up @@ -295,11 +301,29 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] =
return _DESERIALIZE_MAPPING.get(annotation)


def _get_type_alias_type(module_name: str, alias_name: str):
types = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, typing._GenericAlias) # type: ignore
}
if alias_name not in types:
return alias_name
return types[alias_name]


def _get_model(module_name: str, model_name: str):
models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)}
models = {
k: v
for k, v in sys.modules[module_name].__dict__.items()
if isinstance(v, type)
}
module_end = module_name.rsplit(".", 1)[0]
module = sys.modules[module_end]
models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)})
models.update({
k: v
for k, v in sys.modules[module_end].__dict__.items()
if isinstance(v, type)
})
if isinstance(model_name, str):
model_name = model_name.split(".")[-1]
if model_name not in models:
Expand Down Expand Up @@ -461,7 +485,7 @@ class Model(_MyMutableMapping):
raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'")
dict_to_pass.update(
{
self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format)
self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v)
for k, v in kwargs.items()
if v is not None
}
Expand Down Expand Up @@ -499,33 +523,83 @@ class Model(_MyMutableMapping):
base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member

@classmethod
def _get_discriminator(cls) -> typing.Optional[str]:
def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]:
for v in cls.__dict__.values():
if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access
if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access
return v._rest_name # pylint: disable=protected-access
return None

@classmethod
def _deserialize(cls, data):
def _deserialize(cls, data, exist_discriminators):
if not hasattr(cls, "__mapping__"): # pylint: disable=no-member
return cls(data)
discriminator = cls._get_discriminator()
mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member
discriminator = cls._get_discriminator(exist_discriminators)
exist_discriminators.append(discriminator)
mapped_cls = cls.__mapping__.get(
data.get(discriminator), cls
) # pylint: disable=no-member
if mapped_cls == cls:
return cls(data)
return mapped_cls._deserialize(data) # pylint: disable=protected-access


def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements
return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access

def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]:
"""Return a dict that can be JSONify using json.dump.

:keyword bool exclude_readonly: Whether to remove the readonly properties.
:returns: A dict JSON compatible object
:rtype: dict
"""

result = {}
if exclude_readonly:
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
for k, v in self.items():
if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false
continue
result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly)
return result

@staticmethod
def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any:
if v is None or isinstance(v, _Null):
return None
if isinstance(v, (list, tuple, set)):
return [
Model._as_dict_value(x, exclude_readonly=exclude_readonly)
for x in v
]
if isinstance(v, dict):
return {
dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly)
for dk, dv in v.items()
}
return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v


def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912
annotation: typing.Any,
module: typing.Optional[str],
rf: typing.Optional["_RestField"] = None,
) -> typing.Optional[typing.Callable[[typing.Any], typing.Any]]:
if not annotation or annotation in [int, float]:
return None

# is it a type alias?
if isinstance(annotation, str):
if module is not None:
annotation = _get_type_alias_type(module, annotation)

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)

try:
if module and _is_model(_get_model(module, annotation)):
if module and _is_model(annotation):
if rf:
rf._is_model = True

Expand All @@ -534,7 +608,7 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
return obj
return _deserialize(model_deserializer, obj)

return functools.partial(_deserialize_model, _get_model(module, annotation))
return functools.partial(_deserialize_model, annotation)
except Exception:
pass

Expand All @@ -552,22 +626,8 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
except AttributeError:
pass

if getattr(annotation, "__origin__", None) is typing.Union:

def _deserialize_with_union(union_annotation, obj):
for t in union_annotation.__args__:
try:
return _deserialize(t, obj, module, rf)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, annotation)

# is it optional?
try:
# right now, assuming we don't have unions, since we're getting rid of the only
# union we used to have in msrest models, which was union of str and enum
if any(a for a in annotation.__args__ if a == type(None)):
if_obj_deserializer = _get_deserialize_callable_from_annotation(
next(a for a in annotation.__args__ if a != type(None)), module, rf
Expand All @@ -582,14 +642,18 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur
except AttributeError:
pass

# is it a forward ref / in quotes?
if isinstance(annotation, (str, typing.ForwardRef)):
try:
model_name = annotation.__forward_arg__ # type: ignore
except AttributeError:
model_name = annotation
if module is not None:
annotation = _get_model(module, model_name)
if getattr(annotation, "__origin__", None) is typing.Union:
deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__]

def _deserialize_with_union(deserializers, obj):
for deserializer in deserializers:
try:
return _deserialize(deserializer, obj)
except DeserializationError:
pass
raise DeserializationError()

return functools.partial(_deserialize_with_union, deserializers)

try:
if annotation._name == "Dict":
Expand Down Expand Up @@ -680,7 +744,7 @@ def _deserialize_with_callable(
# for unknown value, return raw value
return value
if isinstance(deserializer, type) and issubclass(deserializer, Model):
return deserializer._deserialize(value)
return deserializer._deserialize(value, [])
return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value)
except Exception as e:
raise DeserializationError() from e
Expand Down Expand Up @@ -730,6 +794,8 @@ class _RestField:
item = obj.get(self._rest_name)
if item is None:
return item
if self._is_model:
return item
return _deserialize(self._type, _serialize(item, self._format), rf=self)

def __set__(self, obj: Model, value) -> None:
Expand All @@ -740,8 +806,11 @@ class _RestField:
except KeyError:
pass
return
if self._is_model and not _is_model(value):
obj.__setitem__(self._rest_name, _deserialize(self._type, value))
if self._is_model:
if not _is_model(value):
value = _deserialize(self._type, value)
obj.__setitem__(self._rest_name, value)
return
obj.__setitem__(self._rest_name, _serialize(value, self._format))

def _get_deserialize_callable_from_annotation(
Expand Down
Loading