diff --git a/packages/autorest.python/autorest/codegen/serializers/builder_serializer.py b/packages/autorest.python/autorest/codegen/serializers/builder_serializer.py index 094465775c9..94bea9b1d3d 100644 --- a/packages/autorest.python/autorest/codegen/serializers/builder_serializer.py +++ b/packages/autorest.python/autorest/codegen/serializers/builder_serializer.py @@ -751,7 +751,7 @@ def _serialize_body_parameter(self, builder: OperationType) -> List[str]: elif self.code_model.options["models_mode"] == "dpg": create_body_call = ( f"_{body_kwarg_name} = json.dumps({body_param.client_name}, " - "cls=AzureJSONEncoder) # type: ignore" + "cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore" ) else: create_body_call = f"_{body_kwarg_name} = {body_param.client_name}" diff --git a/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 b/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 index 48acb463cae..17c08a23b49 100644 --- a/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 +++ b/packages/autorest.python/autorest/codegen/templates/model_base.py.jinja2 @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,29 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): - models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} + models = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, type) + } module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({ + k: v + for k, v in sys.modules[module_end].__dict__.items() + if isinstance(v, type) + }) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +485,7 @@ class Model(_MyMutableMapping): raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +523,60 @@ class Model(_MyMutableMapping): base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() - mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) + mapped_cls = cls.__mapping__.get( + data.get(discriminator), cls + ) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access - - -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [ + Model._as_dict_value(x, exclude_readonly=exclude_readonly) + for x in v + ] + if isinstance(v, dict): + return { + dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) + for dk, dv in v.items() + } + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v + + +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +584,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +608,7 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +626,8 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +642,18 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +744,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +794,8 @@ class _RestField: item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +806,11 @@ class _RestField: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/authentication-api-key/authentication/apikey/_model_base.py b/packages/typespec-python/test/generated/authentication-api-key/authentication/apikey/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/authentication-api-key/authentication/apikey/_model_base.py +++ b/packages/typespec-python/test/generated/authentication-api-key/authentication/apikey/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/authentication-http-custom/authentication/http/custom/_model_base.py b/packages/typespec-python/test/generated/authentication-http-custom/authentication/http/custom/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/authentication-http-custom/authentication/http/custom/_model_base.py +++ b/packages/typespec-python/test/generated/authentication-http-custom/authentication/http/custom/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/authentication-oauth2/authentication/oauth2/_model_base.py b/packages/typespec-python/test/generated/authentication-oauth2/authentication/oauth2/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/authentication-oauth2/authentication/oauth2/_model_base.py +++ b/packages/typespec-python/test/generated/authentication-oauth2/authentication/oauth2/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/authentication-union/authentication/union/_model_base.py b/packages/typespec-python/test/generated/authentication-union/authentication/union/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/authentication-union/authentication/union/_model_base.py +++ b/packages/typespec-python/test/generated/authentication-union/authentication/union/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azure-client-generator-core-internal/_specs_/azure/clientgenerator/core/internal/_model_base.py b/packages/typespec-python/test/generated/azure-client-generator-core-internal/_specs_/azure/clientgenerator/core/internal/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azure-client-generator-core-internal/_specs_/azure/clientgenerator/core/internal/_model_base.py +++ b/packages/typespec-python/test/generated/azure-client-generator-core-internal/_specs_/azure/clientgenerator/core/internal/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_model_base.py b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_model_base.py +++ b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_operations/_operations.py index 6dee023a3ac..a7e58e4c9f2 100644 --- a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/_operations/_operations.py @@ -346,7 +346,7 @@ def create_or_update(self, id: int, resource: Union[_models.User, JSON, IO], **k if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_basic_create_or_update_request( id=id, @@ -492,7 +492,7 @@ def create_or_replace(self, id: int, resource: Union[_models.User, JSON, IO], ** if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_basic_create_or_replace_request( id=id, diff --git a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/aio/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/aio/_operations/_operations.py index b5b9993886a..19a5195a1d1 100644 --- a/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-basic/_specs_/azure/core/basic/aio/_operations/_operations.py @@ -154,7 +154,7 @@ async def create_or_update(self, id: int, resource: Union[_models.User, JSON, IO if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_basic_create_or_update_request( id=id, @@ -300,7 +300,7 @@ async def create_or_replace(self, id: int, resource: Union[_models.User, JSON, I if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_basic_create_or_replace_request( id=id, diff --git a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_model_base.py b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_model_base.py +++ b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_operations/_operations.py index ed274c4520d..3e09bf10eef 100644 --- a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/_operations/_operations.py @@ -140,7 +140,7 @@ def _create_or_replace_initial(self, name: str, resource: Union[_models.User, JS if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_standard_create_or_replace_request( name=name, diff --git a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/aio/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/aio/_operations/_operations.py index faaf38333a2..04736e706ed 100644 --- a/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-lro-standard/_specs_/azure/core/lro/standard/aio/_operations/_operations.py @@ -67,7 +67,7 @@ async def _create_or_replace_initial( if isinstance(resource, (IOBase, bytes)): _content = resource else: - _content = json.dumps(resource, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(resource, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_standard_create_or_replace_request( name=name, diff --git a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_model_base.py b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_model_base.py +++ b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_operations/_operations.py index da6edc566ae..1c6e991abd5 100644 --- a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/_operations/_operations.py @@ -316,7 +316,7 @@ def repeatable_action( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_traits_repeatable_action_request( id=id, diff --git a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/aio/_operations/_operations.py b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/aio/_operations/_operations.py index fd95079cf63..1c156d000d6 100644 --- a/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azure-core-traits/_specs_/azure/core/traits/aio/_operations/_operations.py @@ -237,7 +237,7 @@ async def repeatable_action( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_traits_repeatable_action_request( id=id, diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_model_base.py b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_model_base.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_operations/_operations.py b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_operations/_operations.py index 85c49df8429..dd1695657f2 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/_operations/_operations.py @@ -86,7 +86,7 @@ def _long_running_rpc_initial(self, body: Union[_models.GenerationOptions, JSON, if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_rpc_long_running_rpc_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/aio/_operations/_operations.py b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/aio/_operations/_operations.py index d32dc0f8d54..5df16c59f19 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpc/azurecore/lro/rpc/aio/_operations/_operations.py @@ -61,7 +61,7 @@ async def _long_running_rpc_initial(self, body: Union[_models.GenerationOptions, if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_rpc_long_running_rpc_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_model_base.py b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_model_base.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_operations/_operations.py b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_operations/_operations.py index 21738ee004a..c5116e018ad 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/_operations/_operations.py @@ -86,7 +86,7 @@ def _create_job_initial(self, body: Union[_models.JobData, JSON, IO], **kwargs: if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_legacy_create_job_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/aio/_operations/_operations.py b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/aio/_operations/_operations.py index 968eefbdbe8..2f25775f181 100644 --- a/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/azurecore-lro-rpclegacy/azurecore/lro/rpclegacy/aio/_operations/_operations.py @@ -61,7 +61,7 @@ async def _create_job_initial(self, body: Union[_models.JobData, JSON, IO], **kw if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_legacy_create_job_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/client-structure-default/client/structure/service/_model_base.py b/packages/typespec-python/test/generated/client-structure-default/client/structure/service/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/client-structure-default/client/structure/service/_model_base.py +++ b/packages/typespec-python/test/generated/client-structure-default/client/structure/service/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/client-structure-multiclient/client/structure/multiclient/_model_base.py b/packages/typespec-python/test/generated/client-structure-multiclient/client/structure/multiclient/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/client-structure-multiclient/client/structure/multiclient/_model_base.py +++ b/packages/typespec-python/test/generated/client-structure-multiclient/client/structure/multiclient/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/client-structure-renamedoperation/client/structure/renamedoperation/_model_base.py b/packages/typespec-python/test/generated/client-structure-renamedoperation/client/structure/renamedoperation/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/client-structure-renamedoperation/client/structure/renamedoperation/_model_base.py +++ b/packages/typespec-python/test/generated/client-structure-renamedoperation/client/structure/renamedoperation/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/client-structure-twooperationgroup/client/structure/twooperationgroup/_model_base.py b/packages/typespec-python/test/generated/client-structure-twooperationgroup/client/structure/twooperationgroup/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/client-structure-twooperationgroup/client/structure/twooperationgroup/_model_base.py +++ b/packages/typespec-python/test/generated/client-structure-twooperationgroup/client/structure/twooperationgroup/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/_model_base.py b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/_model_base.py +++ b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/aio/operations/_operations.py b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/aio/operations/_operations.py index bbaa55d6b6b..03d2e4587b0 100644 --- a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/aio/operations/_operations.py @@ -369,7 +369,7 @@ async def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -492,7 +492,7 @@ async def base64( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64_request( content_type=content_type, @@ -615,7 +615,7 @@ async def base64url( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64url_request( content_type=content_type, @@ -742,7 +742,7 @@ async def base64url_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64url_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/operations/_operations.py b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/operations/_operations.py index 993eb003a0f..e90bc1975ee 100644 --- a/packages/typespec-python/test/generated/encode-bytes/encode/bytes/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-bytes/encode/bytes/operations/_operations.py @@ -523,7 +523,7 @@ def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -642,7 +642,7 @@ def base64(self, body: Union[_models.Base64BytesProperty, JSON, IO], **kwargs: A if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64_request( content_type=content_type, @@ -765,7 +765,7 @@ def base64url( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64url_request( content_type=content_type, @@ -892,7 +892,7 @@ def base64url_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_base64url_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/_model_base.py b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/_model_base.py +++ b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/aio/operations/_operations.py b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/aio/operations/_operations.py index f2b1214a3c9..5818d66891b 100644 --- a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/aio/operations/_operations.py @@ -429,7 +429,7 @@ async def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -552,7 +552,7 @@ async def rfc3339( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_rfc3339_request( content_type=content_type, @@ -675,7 +675,7 @@ async def rfc7231( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_rfc7231_request( content_type=content_type, @@ -802,7 +802,7 @@ async def unix_timestamp( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_unix_timestamp_request( content_type=content_type, @@ -930,7 +930,7 @@ async def unix_timestamp_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_unix_timestamp_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/operations/_operations.py b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/operations/_operations.py index cce12e072ef..4a4904efb3d 100644 --- a/packages/typespec-python/test/generated/encode-datetime/encode/datetime/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-datetime/encode/datetime/operations/_operations.py @@ -623,7 +623,7 @@ def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -746,7 +746,7 @@ def rfc3339( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_rfc3339_request( content_type=content_type, @@ -869,7 +869,7 @@ def rfc7231( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_rfc7231_request( content_type=content_type, @@ -996,7 +996,7 @@ def unix_timestamp( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_unix_timestamp_request( content_type=content_type, @@ -1124,7 +1124,7 @@ def unix_timestamp_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_unix_timestamp_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/encode-duration/encode/duration/_model_base.py b/packages/typespec-python/test/generated/encode-duration/encode/duration/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/encode-duration/encode/duration/_model_base.py +++ b/packages/typespec-python/test/generated/encode-duration/encode/duration/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/encode-duration/encode/duration/aio/operations/_operations.py b/packages/typespec-python/test/generated/encode-duration/encode/duration/aio/operations/_operations.py index 7511e3bad0c..e8516e9a674 100644 --- a/packages/typespec-python/test/generated/encode-duration/encode/duration/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-duration/encode/duration/aio/operations/_operations.py @@ -429,7 +429,7 @@ async def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -552,7 +552,7 @@ async def iso8601( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_iso8601_request( content_type=content_type, @@ -679,7 +679,7 @@ async def int32_seconds( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_int32_seconds_request( content_type=content_type, @@ -806,7 +806,7 @@ async def float_seconds( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_float_seconds_request( content_type=content_type, @@ -934,7 +934,7 @@ async def float_seconds_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_float_seconds_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/encode-duration/encode/duration/operations/_operations.py b/packages/typespec-python/test/generated/encode-duration/encode/duration/operations/_operations.py index 4c4a980ff38..0b8c6e79234 100644 --- a/packages/typespec-python/test/generated/encode-duration/encode/duration/operations/_operations.py +++ b/packages/typespec-python/test/generated/encode-duration/encode/duration/operations/_operations.py @@ -617,7 +617,7 @@ def default( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_default_request( content_type=content_type, @@ -740,7 +740,7 @@ def iso8601( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_iso8601_request( content_type=content_type, @@ -867,7 +867,7 @@ def int32_seconds( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_int32_seconds_request( content_type=content_type, @@ -994,7 +994,7 @@ def float_seconds( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_float_seconds_request( content_type=content_type, @@ -1122,7 +1122,7 @@ def float_seconds_array( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_float_seconds_array_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_model_base.py b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_model_base.py +++ b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_operations/_operations.py b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_operations/_operations.py index 25a41dff96d..d7e6ce11e43 100644 --- a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_operations/_operations.py +++ b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/_operations/_operations.py @@ -215,7 +215,7 @@ def get_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kwargs: A if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -338,7 +338,7 @@ def head_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -454,7 +454,7 @@ def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -570,7 +570,7 @@ def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -686,7 +686,7 @@ def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -802,7 +802,7 @@ def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/aio/_operations/_operations.py b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/aio/_operations/_operations.py index 887e784e4a7..9ccd7fc3bc2 100644 --- a/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/headasbooleanfalse/headasbooleanfalse/aio/_operations/_operations.py @@ -136,7 +136,7 @@ async def get_model( if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -259,7 +259,7 @@ async def head_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -375,7 +375,7 @@ async def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -491,7 +491,7 @@ async def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -607,7 +607,7 @@ async def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -723,7 +723,7 @@ async def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_model_base.py b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_model_base.py +++ b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_operations/_operations.py b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_operations/_operations.py index 203c46643e1..c2b3984522a 100644 --- a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_operations/_operations.py +++ b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/_operations/_operations.py @@ -215,7 +215,7 @@ def get_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kwargs: A if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -332,7 +332,7 @@ def head_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kwargs: if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -449,7 +449,7 @@ def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -565,7 +565,7 @@ def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -681,7 +681,7 @@ def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -797,7 +797,7 @@ def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/aio/_operations/_operations.py b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/aio/_operations/_operations.py index 4cea4aea0c8..9a42a9a34ab 100644 --- a/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/headasbooleantrue/headasbooleantrue/aio/_operations/_operations.py @@ -136,7 +136,7 @@ async def get_model( if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -253,7 +253,7 @@ async def head_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kw if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -370,7 +370,7 @@ async def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -486,7 +486,7 @@ async def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -602,7 +602,7 @@ async def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -718,7 +718,7 @@ async def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/_model_base.py b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/_model_base.py +++ b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/aio/operations/_operations.py b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/aio/operations/_operations.py index 9ee5cbde8f7..81d8a481cba 100644 --- a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/aio/operations/_operations.py @@ -151,7 +151,7 @@ async def set( # pylint: disable=inconsistent-return-statements _content = body else: if body is not None: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore else: _content = None @@ -270,7 +270,7 @@ async def omit( # pylint: disable=inconsistent-return-statements _content = body else: if body is not None: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore else: _content = None @@ -390,7 +390,7 @@ async def required_explicit( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_body_optionality_required_explicit_request( content_type=content_type, @@ -506,7 +506,7 @@ async def required_implicit( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_body_optionality_required_implicit_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/operations/_operations.py b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/operations/_operations.py index 94202da3399..040e23276dd 100644 --- a/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/operations/_operations.py +++ b/packages/typespec-python/test/generated/parameters-body-optionality/parameters/bodyoptionality/operations/_operations.py @@ -205,7 +205,7 @@ def set( # pylint: disable=inconsistent-return-statements _content = body else: if body is not None: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore else: _content = None @@ -324,7 +324,7 @@ def omit( # pylint: disable=inconsistent-return-statements _content = body else: if body is not None: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore else: _content = None @@ -444,7 +444,7 @@ def required_explicit( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_body_optionality_required_explicit_request( content_type=content_type, @@ -560,7 +560,7 @@ def required_implicit( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_body_optionality_required_implicit_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/parameters-collection-format/parameters/collectionformat/_model_base.py b/packages/typespec-python/test/generated/parameters-collection-format/parameters/collectionformat/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/parameters-collection-format/parameters/collectionformat/_model_base.py +++ b/packages/typespec-python/test/generated/parameters-collection-format/parameters/collectionformat/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/_model_base.py b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/_model_base.py +++ b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/aio/operations/_operations.py b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/aio/operations/_operations.py index ce5987dcf18..439586aaaa0 100644 --- a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/aio/operations/_operations.py @@ -150,7 +150,7 @@ async def spread_as_request_body( # pylint: disable=inconsistent-return-stateme if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_spread_as_request_body_request( content_type=content_type, @@ -307,7 +307,7 @@ async def spread_as_request_body( # pylint: disable=inconsistent-return-stateme if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_as_request_body_request( content_type=content_type, @@ -462,7 +462,7 @@ async def spread_as_request_parameter( # pylint: disable=inconsistent-return-st if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_as_request_parameter_request( id=id, @@ -681,7 +681,7 @@ async def spread_with_multiple_parameters( # pylint: disable=inconsistent-retur if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_with_multiple_parameters_request( id=id, diff --git a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/operations/_operations.py b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/operations/_operations.py index deb1f1d758e..c7ddaecec07 100644 --- a/packages/typespec-python/test/generated/parameters-spread/parameters/spread/operations/_operations.py +++ b/packages/typespec-python/test/generated/parameters-spread/parameters/spread/operations/_operations.py @@ -220,7 +220,7 @@ def spread_as_request_body( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_spread_as_request_body_request( content_type=content_type, @@ -377,7 +377,7 @@ def spread_as_request_body( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_as_request_body_request( content_type=content_type, @@ -532,7 +532,7 @@ def spread_as_request_parameter( # pylint: disable=inconsistent-return-statemen if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_as_request_parameter_request( id=id, @@ -751,7 +751,7 @@ def spread_with_multiple_parameters( # pylint: disable=inconsistent-return-stat if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_alias_spread_with_multiple_parameters_request( id=id, diff --git a/packages/typespec-python/test/generated/payload-content-negotiation/payload/contentnegotiation/_model_base.py b/packages/typespec-python/test/generated/payload-content-negotiation/payload/contentnegotiation/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/payload-content-negotiation/payload/contentnegotiation/_model_base.py +++ b/packages/typespec-python/test/generated/payload-content-negotiation/payload/contentnegotiation/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/_model_base.py b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/_model_base.py +++ b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/aio/operations/_operations.py b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/aio/operations/_operations.py index 7bcda325260..6e38a2da353 100644 --- a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/aio/operations/_operations.py @@ -152,7 +152,7 @@ async def json( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_json_request( content_type=content_type, @@ -268,7 +268,7 @@ async def client( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_client_request( content_type=content_type, @@ -384,7 +384,7 @@ async def language( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_language_request( content_type=content_type, @@ -500,7 +500,7 @@ async def json_and_client( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_json_and_client_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/operations/_operations.py b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/operations/_operations.py index 0159cc252a7..53c6648f6c6 100644 --- a/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/operations/_operations.py +++ b/packages/typespec-python/test/generated/projection-projected-name/projection/projectedname/operations/_operations.py @@ -223,7 +223,7 @@ def json( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_json_request( content_type=content_type, @@ -339,7 +339,7 @@ def client( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_client_request( content_type=content_type, @@ -455,7 +455,7 @@ def language( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_language_request( content_type=content_type, @@ -571,7 +571,7 @@ def json_and_client( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_property_json_and_client_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/resiliency-srv-driven1/resiliency/srv/driven1/_model_base.py b/packages/typespec-python/test/generated/resiliency-srv-driven1/resiliency/srv/driven1/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/resiliency-srv-driven1/resiliency/srv/driven1/_model_base.py +++ b/packages/typespec-python/test/generated/resiliency-srv-driven1/resiliency/srv/driven1/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/resiliency-srv-driven2/resiliency/srv/driven2/_model_base.py b/packages/typespec-python/test/generated/resiliency-srv-driven2/resiliency/srv/driven2/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/resiliency-srv-driven2/resiliency/srv/driven2/_model_base.py +++ b/packages/typespec-python/test/generated/resiliency-srv-driven2/resiliency/srv/driven2/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/server-path-multiple/server/path/multiple/_model_base.py b/packages/typespec-python/test/generated/server-path-multiple/server/path/multiple/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/server-path-multiple/server/path/multiple/_model_base.py +++ b/packages/typespec-python/test/generated/server-path-multiple/server/path/multiple/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/server-path-single/server/path/single/_model_base.py b/packages/typespec-python/test/generated/server-path-single/server/path/single/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/server-path-single/server/path/single/_model_base.py +++ b/packages/typespec-python/test/generated/server-path-single/server/path/single/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/special-headers-client-request-id/specialheaders/clientrequestid/_model_base.py b/packages/typespec-python/test/generated/special-headers-client-request-id/specialheaders/clientrequestid/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/special-headers-client-request-id/specialheaders/clientrequestid/_model_base.py +++ b/packages/typespec-python/test/generated/special-headers-client-request-id/specialheaders/clientrequestid/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/special-headers-repeatability/specialheaders/repeatability/_model_base.py b/packages/typespec-python/test/generated/special-headers-repeatability/specialheaders/repeatability/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/special-headers-repeatability/specialheaders/repeatability/_model_base.py +++ b/packages/typespec-python/test/generated/special-headers-repeatability/specialheaders/repeatability/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/special-words/specialwords/_model_base.py b/packages/typespec-python/test/generated/special-words/specialwords/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/special-words/specialwords/_model_base.py +++ b/packages/typespec-python/test/generated/special-words/specialwords/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/special-words/specialwords/aio/operations/_operations.py b/packages/typespec-python/test/generated/special-words/specialwords/aio/operations/_operations.py index 3aa880f8c3d..3276e3cfcc6 100644 --- a/packages/typespec-python/test/generated/special-words/specialwords/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/special-words/specialwords/aio/operations/_operations.py @@ -383,7 +383,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/special-words/specialwords/operations/_operations.py b/packages/typespec-python/test/generated/special-words/specialwords/operations/_operations.py index edc1c77a2f1..549e3465d76 100644 --- a/packages/typespec-python/test/generated/special-words/specialwords/operations/_operations.py +++ b/packages/typespec-python/test/generated/special-words/specialwords/operations/_operations.py @@ -437,7 +437,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-array/typetest/array/_model_base.py b/packages/typespec-python/test/generated/typetest-array/typetest/array/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-array/typetest/array/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-array/typetest/array/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-array/typetest/array/aio/operations/_operations.py b/packages/typespec-python/test/generated/typetest-array/typetest/array/aio/operations/_operations.py index bd688bf52dc..861fc928a28 100644 --- a/packages/typespec-python/test/generated/typetest-array/typetest/array/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-array/typetest/array/aio/operations/_operations.py @@ -194,7 +194,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int32_value_put_request( content_type=content_type, @@ -362,7 +362,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int64_value_put_request( content_type=content_type, @@ -530,7 +530,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_value_put_request( content_type=content_type, @@ -698,7 +698,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_value_put_request( content_type=content_type, @@ -866,7 +866,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float32_value_put_request( content_type=content_type, @@ -1034,7 +1034,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_value_put_request( content_type=content_type, @@ -1202,7 +1202,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_value_put_request( content_type=content_type, @@ -1370,7 +1370,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_value_put_request( content_type=content_type, @@ -1538,7 +1538,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_value_put_request( content_type=content_type, @@ -1706,7 +1706,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nullable_float_value_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-array/typetest/array/operations/_operations.py b/packages/typespec-python/test/generated/typetest-array/typetest/array/operations/_operations.py index 554b94a51f3..776b2b6a829 100644 --- a/packages/typespec-python/test/generated/typetest-array/typetest/array/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-array/typetest/array/operations/_operations.py @@ -454,7 +454,7 @@ def put(self, body: Union[List[int], IO], **kwargs: Any) -> None: # pylint: dis if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int32_value_put_request( content_type=content_type, @@ -620,7 +620,7 @@ def put(self, body: Union[List[int], IO], **kwargs: Any) -> None: # pylint: dis if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int64_value_put_request( content_type=content_type, @@ -786,7 +786,7 @@ def put(self, body: Union[List[bool], IO], **kwargs: Any) -> None: # pylint: di if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_value_put_request( content_type=content_type, @@ -952,7 +952,7 @@ def put(self, body: Union[List[str], IO], **kwargs: Any) -> None: # pylint: dis if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_value_put_request( content_type=content_type, @@ -1120,7 +1120,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float32_value_put_request( content_type=content_type, @@ -1288,7 +1288,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_value_put_request( content_type=content_type, @@ -1456,7 +1456,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_value_put_request( content_type=content_type, @@ -1622,7 +1622,7 @@ def put(self, body: Union[List[Any], IO], **kwargs: Any) -> None: # pylint: dis if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_value_put_request( content_type=content_type, @@ -1790,7 +1790,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_value_put_request( content_type=content_type, @@ -1958,7 +1958,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nullable_float_value_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/_model_base.py b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/aio/operations/_operations.py b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/aio/operations/_operations.py index 8b9d2841345..33ff6d9d4a3 100644 --- a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/aio/operations/_operations.py @@ -196,7 +196,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int32_value_put_request( content_type=content_type, @@ -364,7 +364,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int64_value_put_request( content_type=content_type, @@ -532,7 +532,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_value_put_request( content_type=content_type, @@ -700,7 +700,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_value_put_request( content_type=content_type, @@ -868,7 +868,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float32_value_put_request( content_type=content_type, @@ -1036,7 +1036,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_value_put_request( content_type=content_type, @@ -1204,7 +1204,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_value_put_request( content_type=content_type, @@ -1372,7 +1372,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_value_put_request( content_type=content_type, @@ -1540,7 +1540,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_value_put_request( content_type=content_type, @@ -1708,7 +1708,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_recursive_model_value_put_request( content_type=content_type, @@ -1876,7 +1876,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nullable_float_value_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/operations/_operations.py b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/operations/_operations.py index 7a06ab704fe..b9dd4e5c0f2 100644 --- a/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-dictionary/typetest/dictionary/operations/_operations.py @@ -484,7 +484,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int32_value_put_request( content_type=content_type, @@ -652,7 +652,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int64_value_put_request( content_type=content_type, @@ -820,7 +820,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_value_put_request( content_type=content_type, @@ -988,7 +988,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_value_put_request( content_type=content_type, @@ -1156,7 +1156,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float32_value_put_request( content_type=content_type, @@ -1324,7 +1324,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_value_put_request( content_type=content_type, @@ -1492,7 +1492,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_value_put_request( content_type=content_type, @@ -1660,7 +1660,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_value_put_request( content_type=content_type, @@ -1828,7 +1828,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_value_put_request( content_type=content_type, @@ -1996,7 +1996,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_recursive_model_value_put_request( content_type=content_type, @@ -2164,7 +2164,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nullable_float_value_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_model_base.py b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_operations/_operations.py index 37e4afe97f1..a7c4308dc50 100644 --- a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/_operations/_operations.py @@ -227,7 +227,7 @@ def put_known_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_put_known_value_request( content_type=content_type, @@ -285,7 +285,7 @@ def put_unknown_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_put_unknown_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/aio/_operations/_operations.py index b02faa4de53..433c66a542d 100644 --- a/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-enum-extensible/typetest/enum/extensible/aio/_operations/_operations.py @@ -173,7 +173,7 @@ async def put_known_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_put_known_value_request( content_type=content_type, @@ -231,7 +231,7 @@ async def put_unknown_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_put_unknown_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_model_base.py b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_operations/_operations.py index 9c2afb65d6d..8513b4be06e 100644 --- a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/_operations/_operations.py @@ -161,7 +161,7 @@ def put_known_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_fixed_put_known_value_request( content_type=content_type, @@ -219,7 +219,7 @@ def put_unknown_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_fixed_put_unknown_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/aio/_operations/_operations.py index 60b0427db76..465010b0954 100644 --- a/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-enum-fixed/typetest/enum/fixed/aio/_operations/_operations.py @@ -120,7 +120,7 @@ async def put_known_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_fixed_put_known_value_request( content_type=content_type, @@ -178,7 +178,7 @@ async def put_unknown_value( # pylint: disable=inconsistent-return-statements content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/json")) cls: ClsType[None] = kwargs.pop("cls", None) - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_fixed_put_unknown_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_model_base.py b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_operations/_operations.py index 5e71ff78f0e..17b1cfc095a 100644 --- a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/_operations/_operations.py @@ -177,7 +177,7 @@ def put_empty( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_empty_put_empty_request( content_type=content_type, @@ -345,7 +345,7 @@ def post_round_trip_empty( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_empty_post_round_trip_empty_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/aio/_operations/_operations.py index f9755259a1c..59fcc2e2b6b 100644 --- a/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-empty/typetest/model/empty/aio/_operations/_operations.py @@ -133,7 +133,7 @@ async def put_empty( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_empty_put_empty_request( content_type=content_type, @@ -301,7 +301,7 @@ async def post_round_trip_empty( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_empty_post_round_trip_empty_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_model_base.py b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_operations/_operations.py index bea01b32d7c..ab903abfbf7 100644 --- a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/_operations/_operations.py @@ -276,7 +276,7 @@ def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nested_discriminator_put_model_request( content_type=content_type, @@ -444,7 +444,7 @@ def put_recursive_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nested_discriminator_put_recursive_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/aio/_operations/_operations.py index 4f0891b2efb..96f23d929f3 100644 --- a/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-nesteddiscriminator/typetest/model/nesteddiscriminator/aio/_operations/_operations.py @@ -188,7 +188,7 @@ async def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nested_discriminator_put_model_request( content_type=content_type, @@ -356,7 +356,7 @@ async def put_recursive_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_nested_discriminator_put_recursive_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_model_base.py b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_operations/_operations.py index 8e729191370..cddd71c3562 100644 --- a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/_operations/_operations.py @@ -177,7 +177,7 @@ def post_valid( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_not_discriminated_post_valid_request( content_type=content_type, @@ -339,7 +339,7 @@ def put_valid(self, input: Union[_models.Siamese, JSON, IO], **kwargs: Any) -> _ if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_not_discriminated_put_valid_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/aio/_operations/_operations.py index fc5a170e549..e96ae0804ec 100644 --- a/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-notdiscriminated/typetest/model/notdiscriminated/aio/_operations/_operations.py @@ -133,7 +133,7 @@ async def post_valid( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_not_discriminated_post_valid_request( content_type=content_type, @@ -295,7 +295,7 @@ async def put_valid(self, input: Union[_models.Siamese, JSON, IO], **kwargs: Any if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_not_discriminated_put_valid_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_model_base.py b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_operations/_operations.py index 8d7f63c7baf..a4c8199b043 100644 --- a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/_operations/_operations.py @@ -290,7 +290,7 @@ def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_single_discriminator_put_model_request( content_type=content_type, @@ -458,7 +458,7 @@ def put_recursive_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_single_discriminator_put_recursive_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/aio/_operations/_operations.py index f34b06087d6..872ad6dc1c5 100644 --- a/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-singlediscriminator/typetest/model/singlediscriminator/aio/_operations/_operations.py @@ -189,7 +189,7 @@ async def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_single_discriminator_put_model_request( content_type=content_type, @@ -357,7 +357,7 @@ async def put_recursive_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_single_discriminator_put_recursive_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_model_base.py b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_operations/_operations.py index fbaffabe16b..68dfcb15045 100644 --- a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/_operations/_operations.py @@ -177,7 +177,7 @@ def input( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_usage_input_request( content_type=content_type, @@ -345,7 +345,7 @@ def input_and_output( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_usage_input_and_output_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/aio/_operations/_operations.py index ba6f6d005f9..35b3ad6f89c 100644 --- a/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-usage/typetest/model/usage/aio/_operations/_operations.py @@ -133,7 +133,7 @@ async def input( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_usage_input_request( content_type=content_type, @@ -301,7 +301,7 @@ async def input_and_output( if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_usage_input_and_output_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_model_base.py b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_operations/_operations.py index 01dfe720242..7cf8c45d865 100644 --- a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/_operations/_operations.py @@ -215,7 +215,7 @@ def get_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kwargs: A if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -332,7 +332,7 @@ def head_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kwargs: if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -449,7 +449,7 @@ def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -565,7 +565,7 @@ def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -681,7 +681,7 @@ def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -797,7 +797,7 @@ def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/aio/_operations/_operations.py index dc17477ca19..8d503633c40 100644 --- a/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-model-visibility/typetest/model/visibility/aio/_operations/_operations.py @@ -136,7 +136,7 @@ async def get_model( if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_get_model_request( content_type=content_type, @@ -253,7 +253,7 @@ async def head_model(self, input: Union[_models.VisibilityModel, JSON, IO], **kw if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_head_model_request( content_type=content_type, @@ -370,7 +370,7 @@ async def put_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_put_model_request( content_type=content_type, @@ -486,7 +486,7 @@ async def patch_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_patch_model_request( content_type=content_type, @@ -602,7 +602,7 @@ async def post_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_post_model_request( content_type=content_type, @@ -718,7 +718,7 @@ async def delete_model( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_visibility_delete_model_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/_model_base.py b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/aio/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/aio/operations/_operations.py index b595ad795cd..b526ea2db5a 100644 --- a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/aio/operations/_operations.py @@ -272,7 +272,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_patch_non_null_request( content_type=content_type, @@ -387,7 +387,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_patch_null_request( content_type=content_type, @@ -624,7 +624,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_patch_non_null_request( content_type=content_type, @@ -739,7 +739,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_patch_null_request( content_type=content_type, @@ -976,7 +976,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_patch_non_null_request( content_type=content_type, @@ -1091,7 +1091,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_patch_null_request( content_type=content_type, @@ -1328,7 +1328,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_patch_non_null_request( content_type=content_type, @@ -1443,7 +1443,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_patch_null_request( content_type=content_type, @@ -1684,7 +1684,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_patch_non_null_request( content_type=content_type, @@ -1803,7 +1803,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_patch_null_request( content_type=content_type, @@ -2046,7 +2046,7 @@ async def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_patch_non_null_request( content_type=content_type, @@ -2165,7 +2165,7 @@ async def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_patch_null_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/operations/_operations.py index ff5823d70a5..2ec0aac155a 100644 --- a/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-nullable/typetest/property/nullable/operations/_operations.py @@ -586,7 +586,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_patch_non_null_request( content_type=content_type, @@ -701,7 +701,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_patch_null_request( content_type=content_type, @@ -938,7 +938,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_patch_non_null_request( content_type=content_type, @@ -1053,7 +1053,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_patch_null_request( content_type=content_type, @@ -1290,7 +1290,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_patch_non_null_request( content_type=content_type, @@ -1405,7 +1405,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_patch_null_request( content_type=content_type, @@ -1642,7 +1642,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_patch_non_null_request( content_type=content_type, @@ -1757,7 +1757,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_patch_null_request( content_type=content_type, @@ -1998,7 +1998,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_patch_non_null_request( content_type=content_type, @@ -2117,7 +2117,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_patch_null_request( content_type=content_type, @@ -2360,7 +2360,7 @@ def patch_non_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_patch_non_null_request( content_type=content_type, @@ -2479,7 +2479,7 @@ def patch_null( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_patch_null_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/_model_base.py b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/aio/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/aio/operations/_operations.py index dcb9dfb01b3..b7d7c006a4c 100644 --- a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/aio/operations/_operations.py @@ -277,7 +277,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_all_request( content_type=content_type, @@ -393,7 +393,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_default_request( content_type=content_type, @@ -631,7 +631,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_all_request( content_type=content_type, @@ -747,7 +747,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_default_request( content_type=content_type, @@ -985,7 +985,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_all_request( content_type=content_type, @@ -1101,7 +1101,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_default_request( content_type=content_type, @@ -1339,7 +1339,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_all_request( content_type=content_type, @@ -1455,7 +1455,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_default_request( content_type=content_type, @@ -1693,7 +1693,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_put_all_request( content_type=content_type, @@ -1809,7 +1809,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_put_default_request( content_type=content_type, @@ -2049,7 +2049,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_all_request( content_type=content_type, @@ -2165,7 +2165,7 @@ async def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_default_request( content_type=content_type, @@ -2405,7 +2405,7 @@ async def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_required_and_optional_put_all_request( content_type=content_type, @@ -2521,7 +2521,7 @@ async def put_required_only( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_required_and_optional_put_required_only_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/operations/_operations.py index b7930543d18..4cfed78fac1 100644 --- a/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-optional/typetest/property/optional/operations/_operations.py @@ -647,7 +647,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_all_request( content_type=content_type, @@ -763,7 +763,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_default_request( content_type=content_type, @@ -1001,7 +1001,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_all_request( content_type=content_type, @@ -1117,7 +1117,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_default_request( content_type=content_type, @@ -1355,7 +1355,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_all_request( content_type=content_type, @@ -1471,7 +1471,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_default_request( content_type=content_type, @@ -1709,7 +1709,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_all_request( content_type=content_type, @@ -1825,7 +1825,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_default_request( content_type=content_type, @@ -2063,7 +2063,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_put_all_request( content_type=content_type, @@ -2179,7 +2179,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_byte_put_default_request( content_type=content_type, @@ -2419,7 +2419,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_all_request( content_type=content_type, @@ -2535,7 +2535,7 @@ def put_default( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_default_request( content_type=content_type, @@ -2775,7 +2775,7 @@ def put_all( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_required_and_optional_put_all_request( content_type=content_type, @@ -2891,7 +2891,7 @@ def put_required_only( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_required_and_optional_put_required_only_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/_model_base.py b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/aio/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/aio/operations/_operations.py index 6c3db6f8be6..141c8da366c 100644 --- a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/aio/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/aio/operations/_operations.py @@ -235,7 +235,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_put_request( content_type=content_type, @@ -421,7 +421,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_request( content_type=content_type, @@ -607,7 +607,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_request( content_type=content_type, @@ -793,7 +793,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int_put_request( content_type=content_type, @@ -979,7 +979,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float_put_request( content_type=content_type, @@ -1165,7 +1165,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_request( content_type=content_type, @@ -1351,7 +1351,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_request( content_type=content_type, @@ -1537,7 +1537,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_enum_put_request( content_type=content_type, @@ -1723,7 +1723,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_enum_put_request( content_type=content_type, @@ -1909,7 +1909,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_put_request( content_type=content_type, @@ -2096,7 +2096,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_string_put_request( content_type=content_type, @@ -2282,7 +2282,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_int_put_request( content_type=content_type, @@ -2469,7 +2469,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_request( content_type=content_type, @@ -2656,7 +2656,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_dictionary_string_put_request( content_type=content_type, @@ -2842,7 +2842,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_never_put_request( content_type=content_type, @@ -3028,7 +3028,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_string_put_request( content_type=content_type, @@ -3214,7 +3214,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_int_put_request( content_type=content_type, @@ -3400,7 +3400,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_dict_put_request( content_type=content_type, @@ -3586,7 +3586,7 @@ async def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_array_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/operations/_operations.py b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/operations/_operations.py index b98d01af050..94164fc701d 100644 --- a/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-property-valuetypes/typetest/property/valuetypes/operations/_operations.py @@ -731,7 +731,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_boolean_put_request( content_type=content_type, @@ -917,7 +917,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_string_put_request( content_type=content_type, @@ -1103,7 +1103,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_bytes_put_request( content_type=content_type, @@ -1289,7 +1289,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_int_put_request( content_type=content_type, @@ -1475,7 +1475,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_float_put_request( content_type=content_type, @@ -1661,7 +1661,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_datetime_put_request( content_type=content_type, @@ -1847,7 +1847,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_duration_put_request( content_type=content_type, @@ -2033,7 +2033,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_enum_put_request( content_type=content_type, @@ -2219,7 +2219,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_extensible_enum_put_request( content_type=content_type, @@ -2405,7 +2405,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_model_put_request( content_type=content_type, @@ -2592,7 +2592,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_string_put_request( content_type=content_type, @@ -2778,7 +2778,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_int_put_request( content_type=content_type, @@ -2965,7 +2965,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_collections_model_put_request( content_type=content_type, @@ -3152,7 +3152,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_dictionary_string_put_request( content_type=content_type, @@ -3338,7 +3338,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_never_put_request( content_type=content_type, @@ -3524,7 +3524,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_string_put_request( content_type=content_type, @@ -3710,7 +3710,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_int_put_request( content_type=content_type, @@ -3896,7 +3896,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_dict_put_request( content_type=content_type, @@ -4082,7 +4082,7 @@ def put( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _content = json.dumps(body, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(body, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_unknown_array_put_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-union/typetest/union/_model_base.py b/packages/typespec-python/test/generated/typetest-union/typetest/union/_model_base.py index 48acb463cae..bbb1857dd64 100644 --- a/packages/typespec-python/test/generated/typetest-union/typetest/union/_model_base.py +++ b/packages/typespec-python/test/generated/typetest-union/typetest/union/_model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,21 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({k: v for k, v in sys.modules[module_end].__dict__.items() if isinstance(v, type)}) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +477,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +515,54 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if ( + isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators + ): # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [Model._as_dict_value(x, exclude_readonly=exclude_readonly) for x in v] + if isinstance(v, dict): + return {dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) for dk, dv in v.items()} + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +570,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +594,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +612,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +628,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +730,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +780,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +792,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/generated/typetest-union/typetest/union/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-union/typetest/union/_operations/_operations.py index defadd84a2f..5ba440f5f39 100644 --- a/packages/typespec-python/test/generated/typetest-union/typetest/union/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-union/typetest/union/_operations/_operations.py @@ -188,7 +188,7 @@ def send_int( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_int_request( content_type=content_type, @@ -304,7 +304,7 @@ def send_int_array( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_int_array_request( content_type=content_type, @@ -420,7 +420,7 @@ def send_first_named_union_value( # pylint: disable=inconsistent-return-stateme if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_first_named_union_value_request( content_type=content_type, @@ -536,7 +536,7 @@ def send_second_named_union_value( # pylint: disable=inconsistent-return-statem if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_second_named_union_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/generated/typetest-union/typetest/union/aio/_operations/_operations.py b/packages/typespec-python/test/generated/typetest-union/typetest/union/aio/_operations/_operations.py index 3a4d3c32ed9..d758589f218 100644 --- a/packages/typespec-python/test/generated/typetest-union/typetest/union/aio/_operations/_operations.py +++ b/packages/typespec-python/test/generated/typetest-union/typetest/union/aio/_operations/_operations.py @@ -134,7 +134,7 @@ async def send_int( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_int_request( content_type=content_type, @@ -250,7 +250,7 @@ async def send_int_array( # pylint: disable=inconsistent-return-statements if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_int_array_request( content_type=content_type, @@ -366,7 +366,7 @@ async def send_first_named_union_value( # pylint: disable=inconsistent-return-s if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_first_named_union_value_request( content_type=content_type, @@ -482,7 +482,7 @@ async def send_second_named_union_value( # pylint: disable=inconsistent-return- if isinstance(input, (IOBase, bytes)): _content = input else: - _content = json.dumps(input, cls=AzureJSONEncoder) # type: ignore + _content = json.dumps(input, cls=AzureJSONEncoder, exclude_readonly=True) # type: ignore request = build_union_send_second_named_union_value_request( content_type=content_type, diff --git a/packages/typespec-python/test/unittests/generated/model_base.py b/packages/typespec-python/test/unittests/generated/model_base.py index 48acb463cae..17c08a23b49 100644 --- a/packages/typespec-python/test/unittests/generated/model_base.py +++ b/packages/typespec-python/test/unittests/generated/model_base.py @@ -128,10 +128,16 @@ def _is_readonly(p): class AzureJSONEncoder(JSONEncoder): """A JSON encoder that's capable of serializing datetime objects and bytes.""" + def __init__(self, *args, exclude_readonly: bool = False, **kwargs): + super().__init__(*args, **kwargs) + self.exclude_readonly = exclude_readonly + def default(self, o): # pylint: disable=too-many-return-statements if _is_model(o): - readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] - return {k: v for k, v in o.items() if k not in readonly_props} + if self.exclude_readonly: + readonly_props = [p._rest_name for p in o._attr_to_rest_field.values() if _is_readonly(p)] + return {k: v for k, v in o.items() if k not in readonly_props} + return dict(o.items()) if isinstance(o, (bytes, bytearray)): return base64.b64encode(o).decode() if isinstance(o, _Null): @@ -295,11 +301,29 @@ def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = return _DESERIALIZE_MAPPING.get(annotation) +def _get_type_alias_type(module_name: str, alias_name: str): + types = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, typing._GenericAlias) # type: ignore + } + if alias_name not in types: + return alias_name + return types[alias_name] + + def _get_model(module_name: str, model_name: str): - models = {k: v for k, v in sys.modules[module_name].__dict__.items() if isinstance(v, type)} + models = { + k: v + for k, v in sys.modules[module_name].__dict__.items() + if isinstance(v, type) + } module_end = module_name.rsplit(".", 1)[0] - module = sys.modules[module_end] - models.update({k: v for k, v in module.__dict__.items() if isinstance(v, type)}) + models.update({ + k: v + for k, v in sys.modules[module_end].__dict__.items() + if isinstance(v, type) + }) if isinstance(model_name, str): model_name = model_name.split(".")[-1] if model_name not in models: @@ -461,7 +485,7 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: raise TypeError(f"{class_name}.__init__() got an unexpected keyword argument '{non_attr_kwargs[0]}'") dict_to_pass.update( { - self._attr_to_rest_field[k]._rest_name: _serialize(v, self._attr_to_rest_field[k]._format) + self._attr_to_rest_field[k]._rest_name: _create_value(self._attr_to_rest_field[k], v) for k, v in kwargs.items() if v is not None } @@ -499,24 +523,60 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None: base.__mapping__[discriminator or cls.__name__] = cls # type: ignore # pylint: disable=no-member @classmethod - def _get_discriminator(cls) -> typing.Optional[str]: + def _get_discriminator(cls, exist_discriminators) -> typing.Optional[str]: for v in cls.__dict__.values(): - if isinstance(v, _RestField) and v._is_discriminator: # pylint: disable=protected-access + if isinstance(v, _RestField) and v._is_discriminator and v._rest_name not in exist_discriminators: # pylint: disable=protected-access return v._rest_name # pylint: disable=protected-access return None @classmethod - def _deserialize(cls, data): + def _deserialize(cls, data, exist_discriminators): if not hasattr(cls, "__mapping__"): # pylint: disable=no-member return cls(data) - discriminator = cls._get_discriminator() - mapped_cls = cls.__mapping__.get(data.get(discriminator), cls) # pylint: disable=no-member + discriminator = cls._get_discriminator(exist_discriminators) + exist_discriminators.append(discriminator) + mapped_cls = cls.__mapping__.get( + data.get(discriminator), cls + ) # pylint: disable=no-member if mapped_cls == cls: return cls(data) - return mapped_cls._deserialize(data) # pylint: disable=protected-access - - -def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-return-statements, too-many-statements + return mapped_cls._deserialize(data, exist_discriminators) # pylint: disable=protected-access + + def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.Any]: + """Return a dict that can be JSONify using json.dump. + + :keyword bool exclude_readonly: Whether to remove the readonly properties. + :returns: A dict JSON compatible object + :rtype: dict + """ + + result = {} + if exclude_readonly: + readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)] + for k, v in self.items(): + if exclude_readonly and k in readonly_props: # pyright: reportUnboundVariable=false + continue + result[k] = Model._as_dict_value(v, exclude_readonly=exclude_readonly) + return result + + @staticmethod + def _as_dict_value(v: typing.Any, exclude_readonly: bool = False) -> typing.Any: + if v is None or isinstance(v, _Null): + return None + if isinstance(v, (list, tuple, set)): + return [ + Model._as_dict_value(x, exclude_readonly=exclude_readonly) + for x in v + ] + if isinstance(v, dict): + return { + dk: Model._as_dict_value(dv, exclude_readonly=exclude_readonly) + for dk, dv in v.items() + } + return v.as_dict(exclude_readonly=exclude_readonly) if hasattr(v, "as_dict") else v + + +def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915, R0912 annotation: typing.Any, module: typing.Optional[str], rf: typing.Optional["_RestField"] = None, @@ -524,8 +584,22 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=too-many-retur if not annotation or annotation in [int, float]: return None + # is it a type alias? + if isinstance(annotation, str): + if module is not None: + annotation = _get_type_alias_type(module, annotation) + + # is it a forward ref / in quotes? + if isinstance(annotation, (str, typing.ForwardRef)): + try: + model_name = annotation.__forward_arg__ # type: ignore + except AttributeError: + model_name = annotation + if module is not None: + annotation = _get_model(module, model_name) + try: - if module and _is_model(_get_model(module, annotation)): + if module and _is_model(annotation): if rf: rf._is_model = True @@ -534,7 +608,7 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj return obj return _deserialize(model_deserializer, obj) - return functools.partial(_deserialize_model, _get_model(module, annotation)) + return functools.partial(_deserialize_model, annotation) except Exception: pass @@ -552,22 +626,8 @@ def _deserialize_model(model_deserializer: typing.Optional[typing.Callable], obj except AttributeError: pass - if getattr(annotation, "__origin__", None) is typing.Union: - - def _deserialize_with_union(union_annotation, obj): - for t in union_annotation.__args__: - try: - return _deserialize(t, obj, module, rf) - except DeserializationError: - pass - raise DeserializationError() - - return functools.partial(_deserialize_with_union, annotation) - # is it optional? try: - # right now, assuming we don't have unions, since we're getting rid of the only - # union we used to have in msrest models, which was union of str and enum if any(a for a in annotation.__args__ if a == type(None)): if_obj_deserializer = _get_deserialize_callable_from_annotation( next(a for a in annotation.__args__ if a != type(None)), module, rf @@ -582,14 +642,18 @@ def _deserialize_with_optional(if_obj_deserializer: typing.Optional[typing.Calla except AttributeError: pass - # is it a forward ref / in quotes? - if isinstance(annotation, (str, typing.ForwardRef)): - try: - model_name = annotation.__forward_arg__ # type: ignore - except AttributeError: - model_name = annotation - if module is not None: - annotation = _get_model(module, model_name) + if getattr(annotation, "__origin__", None) is typing.Union: + deserializers = [_get_deserialize_callable_from_annotation(arg, module, rf) for arg in annotation.__args__] + + def _deserialize_with_union(deserializers, obj): + for deserializer in deserializers: + try: + return _deserialize(deserializer, obj) + except DeserializationError: + pass + raise DeserializationError() + + return functools.partial(_deserialize_with_union, deserializers) try: if annotation._name == "Dict": @@ -680,7 +744,7 @@ def _deserialize_with_callable( # for unknown value, return raw value return value if isinstance(deserializer, type) and issubclass(deserializer, Model): - return deserializer._deserialize(value) + return deserializer._deserialize(value, []) return typing.cast(typing.Callable[[typing.Any], typing.Any], deserializer)(value) except Exception as e: raise DeserializationError() from e @@ -730,6 +794,8 @@ def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin item = obj.get(self._rest_name) if item is None: return item + if self._is_model: + return item return _deserialize(self._type, _serialize(item, self._format), rf=self) def __set__(self, obj: Model, value) -> None: @@ -740,8 +806,11 @@ def __set__(self, obj: Model, value) -> None: except KeyError: pass return - if self._is_model and not _is_model(value): - obj.__setitem__(self._rest_name, _deserialize(self._type, value)) + if self._is_model: + if not _is_model(value): + value = _deserialize(self._type, value) + obj.__setitem__(self._rest_name, value) + return obj.__setitem__(self._rest_name, _serialize(value, self._format)) def _get_deserialize_callable_from_annotation( diff --git a/packages/typespec-python/test/unittests/test_model_base_serialization.py b/packages/typespec-python/test/unittests/test_model_base_serialization.py index 03b1653b05b..06474fee6a9 100644 --- a/packages/typespec-python/test/unittests/test_model_base_serialization.py +++ b/packages/typespec-python/test/unittests/test_model_base_serialization.py @@ -5,10 +5,12 @@ import copy import json import datetime -from typing import Any, Iterable, List, Literal, Dict, Mapping, Sequence, Set, Tuple, Optional, overload +from typing import Any, Iterable, List, Literal, Dict, Mapping, Sequence, Set, Tuple, Optional, overload, Union import pytest import isodate -from generated.model_base import AzureJSONEncoder, Model, rest_field +from azure.core.serialization import NULL + +from generated.model_base import AzureJSONEncoder, Model, rest_field, _is_model, rest_discriminator class BasicResource(Model): @@ -873,8 +875,7 @@ def test_model_recursion_complex(): assert isinstance(model.list_of_dict_of_me[0], Dict) assert isinstance(model.list_of_dict_of_me[0]["me"], RecursiveModel) - assert json.loads(json.dumps(dict(model))) == model == dict_response - assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == model == dict_response + assert model.as_dict() == model == dict_response def test_literals(): @@ -1855,9 +1856,26 @@ def test_deserialization_is(): assert x.y.z.zval == isodate.parse_datetime(serialized_datetime) +class InnerModelWithReadonly(Model): + normal_property: str = rest_field(name="normalProperty") + readonly_property: str = rest_field(name="readonlyProperty", visibility=["read"]) + + @overload + def __init__(self, *, normal_property: str): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any], /): + ... + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class ModelWithReadonly(Model): normal_property: str = rest_field(name="normalProperty") readonly_property: str = rest_field(name="readonlyProperty", visibility=["read"]) + inner_model: InnerModelWithReadonly = rest_field(name="innerModel") @overload def __init__(self, *, normal_property: str): @@ -1873,25 +1891,62 @@ def __init__(self, *args, **kwargs): def test_readonly(): # we pass the dict to json, so readonly shouldn't show up in the JSON version - model = ModelWithReadonly({"normalProperty": "normal", "readonlyProperty": "readonly"}) - assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == {"normalProperty": "normal"} - assert model == {"normalProperty": "normal", "readonlyProperty": "readonly"} + value = { + "normalProperty": "normal", + "readonlyProperty": "readonly", + "innerModel": { + "normalProperty": "normal", + "readonlyProperty": "readonly" + } + } + model = ModelWithReadonly(value) + assert model.as_dict(exclude_readonly=True) == {"normalProperty": "normal", + "innerModel": {"normalProperty": "normal"}} + assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == value + assert model == value assert model["readonlyProperty"] == model.readonly_property == "readonly" + assert model["innerModel"]["readonlyProperty"] == model.inner_model.readonly_property == "readonly" def test_readonly_set(): - model = ModelWithReadonly({"normalProperty": "normal", "readonlyProperty": "readonly"}) + value = { + "normalProperty": "normal", + "readonlyProperty": "readonly", + "innerModel": { + "normalProperty": "normal", + "readonlyProperty": "readonly" + } + } + + model = ModelWithReadonly(value) assert model.normal_property == model["normalProperty"] == "normal" assert model.readonly_property == model["readonlyProperty"] == "readonly" + assert model.inner_model.normal_property == model.inner_model["normalProperty"] == "normal" + assert model.inner_model.readonly_property == model.inner_model["readonlyProperty"] == "readonly" - assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == {"normalProperty": "normal"} + assert model.as_dict(exclude_readonly=True) == {"normalProperty": "normal", + "innerModel": {"normalProperty": "normal"}} + assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == value model["normalProperty"] = "setWithDict" model["readonlyProperty"] = "setWithDict" + model.inner_model["normalProperty"] = "setWithDict" + model.inner_model["readonlyProperty"] = "setWithDict" assert model.normal_property == model["normalProperty"] == "setWithDict" assert model.readonly_property == model["readonlyProperty"] == "setWithDict" - assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == {"normalProperty": "setWithDict"} + assert model.inner_model.normal_property == model.inner_model["normalProperty"] == "setWithDict" + assert model.inner_model.readonly_property == model.inner_model["readonlyProperty"] == "setWithDict" + assert model.as_dict(exclude_readonly=True) == {"normalProperty": "setWithDict", + "innerModel": {"normalProperty": "setWithDict"}} + assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == { + "normalProperty": "setWithDict", + "readonlyProperty": "setWithDict", + "innerModel": { + "normalProperty": "setWithDict", + "readonlyProperty": "setWithDict" + } + } def test_incorrect_initialization(): @@ -3444,3 +3499,377 @@ def __init__(self, *args, **kwargs): assert model.required_property is None with pytest.raises(KeyError): model["requiredProperty"] + + +def test_null_serilization(): + dict_response = { + "name": "it's me!", + "listOfMe": [ + { + "name": "it's me!", + } + ], + "dictOfMe": { + "me": { + "name": "it's me!", + } + }, + "dictOfListOfMe": { + "many mes": [ + { + "name": "it's me!", + } + ] + }, + "listOfDictOfMe": None + } + model = RecursiveModel(dict_response) + assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == dict_response + + assert model.as_dict() == dict_response + + model.list_of_me = NULL + model.dict_of_me = None + model.list_of_dict_of_me = [ + { + "me": { + "name": "it's me!", + } + } + ] + model.dict_of_list_of_me["many mes"][0].list_of_me = NULL + model.dict_of_list_of_me["many mes"][0].dict_of_me = None + model.list_of_dict_of_me[0]["me"].list_of_me = NULL + model.list_of_dict_of_me[0]["me"].dict_of_me = None + + assert json.loads(json.dumps(model, cls=AzureJSONEncoder)) == { + "name": "it's me!", + "listOfMe": None, + "dictOfListOfMe": { + "many mes": [ + { + "name": "it's me!", + "listOfMe": None, + } + ] + }, + "listOfDictOfMe": [ + { + "me": { + "name": "it's me!", + "listOfMe": None, + } + } + ] + } + + assert model.as_dict() == { + "name": "it's me!", + "listOfMe": None, + "dictOfListOfMe": { + "many mes": [ + { + "name": "it's me!", + "listOfMe": None, + } + ] + }, + "listOfDictOfMe": [ + { + "me": { + "name": "it's me!", + "listOfMe": None, + } + } + ] + } + + +class UnionBaseModel(Model): + name: str = rest_field() + + @overload + def __init__(self, *, name: str): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class UnionModel1(UnionBaseModel): + prop1: int = rest_field() + + @overload + def __init__(self, *, name: str, prop1: int): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class UnionModel2(UnionBaseModel): + prop2: int = rest_field() + + @overload + def __init__(self, *, name: str, prop2: int): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +MyNamedUnion = Union["UnionModel1", "UnionModel2"] + + +class ModelWithNamedUnionProperty(Model): + named_union: "MyNamedUnion" = rest_field(name="namedUnion") + + @overload + def __init__(self, *, named_union: "MyNamedUnion"): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +class ModelWithSimpleUnionProperty(Model): + simple_union: Union[int, List[int]] = rest_field(name="simpleUnion") + + @overload + def __init__(self, *, simple_union: Union[int, List[int]]): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + +def test_union(): + simple = ModelWithSimpleUnionProperty(simple_union=1) + assert simple.simple_union == simple["simpleUnion"] == 1 + simple = ModelWithSimpleUnionProperty(simple_union=[1, 2]) + assert simple.simple_union == simple["simpleUnion"] == [1, 2] + named = ModelWithNamedUnionProperty() + assert not _is_model(named.named_union) + named.named_union = UnionModel1(name="model1", prop1=1) + assert _is_model(named.named_union) + assert named.named_union == named["namedUnion"] == {"name": "model1", "prop1": 1} + named = ModelWithNamedUnionProperty(named_union=UnionModel2(name="model2", prop2=2)) + assert named.named_union == named["namedUnion"] == {"name": "model2", "prop2": 2} + named = ModelWithNamedUnionProperty({"namedUnion": {"name": "model2", "prop2": 2}}) + assert named.named_union == named["namedUnion"] == {"name": "model2", "prop2": 2} + + +def test_as_dict(): + class CatComplex(PetComplex): + color: Optional[str] = rest_field(default=None) + hates: Optional[List[DogComplex]] = rest_field(default=None, visibility=["read"]) + + @overload + def __init__( + self, + *, + id: Optional[int] = None, + name: Optional[str] = None, + food: Optional[str] = None, + color: Optional[str] = None, + hates: Optional[List[DogComplex]] = None, + ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any], /): + ... + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model = CatComplex(id=2, name="Siameeee", hates=[ + DogComplex(id=1, name="Potato", food="tomato"), + DogComplex(id=-1, name="Tomato", food="french fries") + ]) + assert model.as_dict(exclude_readonly=True) == { + "id": 2, + "name": "Siameeee", + "color": None + } + + +class Fish(Model): + __mapping__: Dict[str, Model] = {} + age: int = rest_field() + kind: Literal[None] = rest_discriminator(name="kind") + + @overload + def __init__(self, *, age: int, ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.kind: Literal[None] = None + + +class Shark(Fish, discriminator="shark"): + __mapping__: Dict[str, Model] = {} + kind: Literal["shark"] = rest_discriminator(name="kind") + sharktype: Literal[None] = rest_discriminator(name="sharktype") + + @overload + def __init__(self, *, age: int, ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.kind: Literal["shark"] = "shark" + self.sharktype: Literal[None] = None + + +class GoblinShark(Shark, discriminator="goblin"): + sharktype: Literal["goblin"] = rest_discriminator(name="sharktype") + + @overload + def __init__(self, *, age: int, ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.sharktype: Literal["goblin"] = "goblin" + + +class Salmon(Fish, discriminator="salmon"): + kind: Literal["salmon"] = rest_discriminator(name="kind") + friends: Optional[List["Fish"]] = rest_field() + hate: Optional[Dict[str, "Fish"]] = rest_field() + partner: Optional["Fish"] = rest_field() + + @overload + def __init__( + self, + *, + age: int, + friends: Optional[List["Fish"]] = None, + hate: Optional[Dict[str, "Fish"]] = None, + partner: Optional["Fish"] = None, + ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.kind: Literal["salmon"] = "salmon" + + +class SawShark(Shark, discriminator="saw"): + sharktype: Literal["saw"] = rest_discriminator(name="sharktype") + + @overload + def __init__(self, *, age: int, ): + ... + + @overload + def __init__(self, mapping: Mapping[str, Any]): + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.sharktype: Literal["saw"] = "saw" + + +def test_discriminator(): + input = { + "age": 1, + "kind": "salmon", + "partner": { + "age": 2, + "kind": "shark", + "sharktype": "saw", + }, + "friends": [ + { + "age": 2, + "kind": "salmon", + "partner": { + "age": 3, + "kind": "salmon", + }, + "hate": { + "key1": { + "age": 4, + "kind": "salmon", + }, + "key2": { + "age": 2, + "kind": "shark", + "sharktype": "goblin", + }, + }, + }, + { + "age": 3, + "kind": "shark", + "sharktype": "goblin", + }, + ], + "hate": { + "key3": { + "age": 3, + "kind": "shark", + "sharktype": "saw", + }, + "key4": { + "age": 2, + "kind": "salmon", + "friends": [ + { + "age": 1, + "kind": "salmon", + }, + { + "age": 4, + "kind": "shark", + "sharktype": "goblin", + }, + ], + }, + }, + } + + model = Salmon(input) + assert model == input + assert model.partner.age == 2 + assert model.partner == SawShark(age=2) + assert model.friends[0].hate["key2"] == GoblinShark(age=2)