Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Change Log

### 2020-xx-xx - 5.4.3
Autorest core version: 3.0.6320

Modelerfour version: 4.15.421

**Bug Fixes**

- Fix conflict for model deserialization when operation has input param with name `models` #819

### 2020-11-09 - 5.4.2
Autorest core version: 3.0.6320

Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/enum_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def operation_type_annotation(self) -> str:
:return: The type annotation for this schema
:rtype: str
"""
return f'Union[{self.enum_type.type_annotation}, "models.{self.name}"]'
return f'Union[{self.enum_type.type_annotation}, "_models.{self.name}"]'

def get_declaration(self, value: Any) -> str:
return f'"{value}"'
Expand Down
27 changes: 20 additions & 7 deletions autorest/codegen/models/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
# --------------------------------------------------------------------------
from enum import Enum
from typing import Dict, Optional, Set
from typing import Dict, Optional, Set, Tuple, Union


class ImportType(str, Enum):
Expand All @@ -21,20 +21,27 @@ class TypingSection(str, Enum):

class FileImport:
def __init__(
self, imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = None
self,
imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
] = None
) -> None:
# Basic implementation
# First level dict: TypingSection
# Second level dict: ImportType
# Third level dict: the package name.
# Fourth level set: None if this import is a "import", the name to import if it's a "from"
self._imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = imports or dict()
self._imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
] = imports or dict()

def _add_import(
self,
from_section: str,
import_type: ImportType,
name_import: Optional[str] = None,
name_import: Optional[Union[str, Tuple[str, str]]] = None,
typing_section: TypingSection = TypingSection.REGULAR
) -> None:
self._imports.setdefault(
Expand All @@ -50,11 +57,14 @@ def add_from_import(
from_section: str,
name_import: str,
import_type: ImportType,
typing_section: TypingSection = TypingSection.REGULAR
typing_section: TypingSection = TypingSection.REGULAR,
alias: Optional[str] = None,
) -> None:
"""Add an import to this import block.
"""
self._add_import(from_section, import_type, name_import, typing_section)
self._add_import(
from_section, import_type, (name_import, alias) if alias else name_import, typing_section
)

def add_import(
self,
Expand All @@ -66,7 +76,10 @@ def add_import(
self._add_import(name_import, import_type, None, typing_section)

@property
def imports(self) -> Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]:
def imports(self) -> Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
]:
return self._imports

def merge(self, file_import: "FileImport") -> None:
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/object_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def type_annotation(self) -> str:

@property
def operation_type_annotation(self) -> str:
return f'"models.{self.name}"'
return f'"_models.{self.name}"'

@property
def docstring_type(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def default_exception(self) -> Optional[str]:
return None
excep_schema = default_excp[0].schema
if isinstance(excep_schema, ObjectSchema):
return f"models.{excep_schema.name}"
return f"_models.{excep_schema.name}"
# in this case, it's just an AnySchema
return "\'object\'"

Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/operation_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def imports(self, async_mode: bool, has_schemas: bool) -> FileImport:
)
if has_schemas:
if async_mode:
file_import.add_from_import("...", "models", ImportType.LOCAL)
file_import.add_from_import("...", "models", ImportType.LOCAL, alias="_models")
else:
file_import.add_from_import("..", "models", ImportType.LOCAL)
file_import.add_from_import("..", "models", ImportType.LOCAL, alias="_models")
return file_import

@property
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/parameter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ def build_flattened_object(self) -> str:
]
)
object_schema = cast(ObjectSchema, self.body[0].schema)
return f"{self.body[0].serialized_name} = models.{object_schema.name}({parameter_string})"
return f"{self.body[0].serialized_name} = _models.{object_schema.name}({parameter_string})"
16 changes: 11 additions & 5 deletions autorest/codegen/serializers/import_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,36 @@
# license information.
# --------------------------------------------------------------------------
from copy import deepcopy
from typing import Dict, Set, Optional, List
from typing import Dict, Set, Optional, List, Tuple, Union
from ..models.imports import ImportType, FileImport, TypingSection

def _serialize_package(package_name: str, module_list: Set[Optional[str]], delimiter: str) -> str:
def _serialize_package(
package_name: str, module_list: Set[Optional[Union[str, Tuple[str, str]]]], delimiter: str
) -> str:
buffer = []
if None in module_list:
buffer.append(f"import {package_name}")
if module_list != {None}:
buffer.append(
"from {} import {}".format(
package_name, ", ".join(sorted([mod for mod in module_list if mod is not None]))
package_name, ", ".join(sorted([
mod if isinstance(mod, str) else f"{mod[0]} as {mod[1]}" for mod in module_list if mod is not None
]))
)
)
return delimiter.join(buffer)

def _serialize_type(import_type_dict: Dict[str, Set[Optional[str]]], delimiter: str) -> str:
def _serialize_type(import_type_dict: Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]], delimiter: str) -> str:
"""Serialize a given import type."""
import_list = []
for package_name in sorted(list(import_type_dict.keys())):
module_list = import_type_dict[package_name]
import_list.append(_serialize_package(package_name, module_list, delimiter))
return delimiter.join(import_list)

def _get_import_clauses(imports: Dict[ImportType, Dict[str, Set[Optional[str]]]], delimiter: str) -> List[str]:
def _get_import_clauses(
imports: Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]], delimiter: str
) -> List[str]:
import_clause = []
for import_type in ImportType:
if import_type in imports:
Expand Down
7 changes: 5 additions & 2 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
import copy
import json
from typing import List, Optional, Set, Tuple, Dict
from typing import List, Optional, Set, Tuple, Dict, Union
from jinja2 import Environment
from ..models import (
CodeModel,
Expand All @@ -26,7 +26,10 @@ def _correct_credential_parameter(global_parameters: ParameterList, async_mode:
credential_param.schema = TokenCredentialSchema(async_mode=async_mode)

def _json_serialize_imports(
imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]
imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
]
):
if not imports:
return None
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/templates/operation_tools.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ error_map = {
{% endif %}
{% for excep in operation.status_code_exceptions %}
{% for status_code in excep.status_codes %}
{% set error_model = ", model=self._deserialize(models." + excep.serialization_type + ", response)" if excep.is_exception else "" %}
{% set error_model = ", model=self._deserialize(_models." + excep.serialization_type + ", response)" if excep.is_exception else "" %}
{% set error_format = ", error_format=ARMErrorFormat" if code_model.options['azure_arm'] else "" %}
{% if status_code == 401 %}
401: lambda response: ClientAuthenticationError(response=response{{ error_model }}{{ error_format }}),
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/templates/operations_container.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class {{ operation_group.class_name }}{{ object_base_class }}:
"""

{% if code_model.schemas %}
models = models
models = _models

{% endif %}
def __init__(self, client, config, serializer, deserializer){{ return_none_type_annotation }}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ... import models
from ... import models as _models

T = TypeVar('T')
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]]
Expand All @@ -33,7 +33,7 @@ class DurationOperations:
:param deserializer: An object model deserializer.
"""

models = models
models = _models

def __init__(self, client, config, serializer, deserializer) -> None:
self._client = client
Expand Down Expand Up @@ -76,7 +76,7 @@ async def get_null(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -130,7 +130,7 @@ async def put_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

if cls:
Expand Down Expand Up @@ -173,7 +173,7 @@ async def get_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -219,7 +219,7 @@ async def get_invalid(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.tracing.decorator import distributed_trace

from .. import models
from .. import models as _models

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
Expand All @@ -37,7 +37,7 @@ class DurationOperations(object):
:param deserializer: An object model deserializer.
"""

models = models
models = _models

def __init__(self, client, config, serializer, deserializer):
self._client = client
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_null(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -136,7 +136,7 @@ def put_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

if cls:
Expand Down Expand Up @@ -180,7 +180,7 @@ def get_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -227,7 +227,7 @@ def get_invalid(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down
Loading