Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Change Log

### 2020-xx-xx - 5.4.3
Autorest core version: 3.0.6320

Modelerfour version: 4.15.421

**Bug Fixes**

- Fix conflict for model deserialization when operation has input param with name `models` #819

### 2020-11-09 - 5.4.2
Autorest core version: 3.0.6320

Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/enum_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def operation_type_annotation(self) -> str:
:return: The type annotation for this schema
:rtype: str
"""
return f'Union[{self.enum_type.type_annotation}, "models.{self.name}"]'
return f'Union[{self.enum_type.type_annotation}, "_models.{self.name}"]'

def get_declaration(self, value: Any) -> str:
return f'"{value}"'
Expand Down
27 changes: 20 additions & 7 deletions autorest/codegen/models/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# license information.
# --------------------------------------------------------------------------
from enum import Enum
from typing import Dict, Optional, Set
from typing import Dict, Optional, Set, Tuple, Union


class ImportType(str, Enum):
Expand All @@ -21,20 +21,27 @@ class TypingSection(str, Enum):

class FileImport:
def __init__(
self, imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = None
self,
imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
] = None
) -> None:
# Basic implementation
# First level dict: TypingSection
# Second level dict: ImportType
# Third level dict: the package name.
# Fourth level set: None if this import is a "import", the name to import if it's a "from"
self._imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]] = imports or dict()
self._imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
] = imports or dict()

def _add_import(
self,
from_section: str,
import_type: ImportType,
name_import: Optional[str] = None,
name_import: Optional[Union[str, Tuple[str, str]]] = None,
typing_section: TypingSection = TypingSection.REGULAR
) -> None:
self._imports.setdefault(
Expand All @@ -50,11 +57,14 @@ def add_from_import(
from_section: str,
name_import: str,
import_type: ImportType,
typing_section: TypingSection = TypingSection.REGULAR
typing_section: TypingSection = TypingSection.REGULAR,
alias: Optional[str] = None,
) -> None:
"""Add an import to this import block.
"""
self._add_import(from_section, import_type, name_import, typing_section)
self._add_import(
from_section, import_type, (name_import, alias) if alias else name_import, typing_section
)

def add_import(
self,
Expand All @@ -66,7 +76,10 @@ def add_import(
self._add_import(name_import, import_type, None, typing_section)

@property
def imports(self) -> Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]:
def imports(self) -> Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
]:
return self._imports

def merge(self, file_import: "FileImport") -> None:
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/object_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def type_annotation(self) -> str:

@property
def operation_type_annotation(self) -> str:
return f'"models.{self.name}"'
return f'"_models.{self.name}"'

@property
def docstring_type(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def default_exception(self) -> Optional[str]:
return None
excep_schema = default_excp[0].schema
if isinstance(excep_schema, ObjectSchema):
return f"models.{excep_schema.name}"
return f"_models.{excep_schema.name}"
# in this case, it's just an AnySchema
return "\'object\'"

Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/models/operation_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def imports(self, async_mode: bool, has_schemas: bool) -> FileImport:
)
if has_schemas:
if async_mode:
file_import.add_from_import("...", "models", ImportType.LOCAL)
file_import.add_from_import("...", "models", ImportType.LOCAL, alias="_models")
else:
file_import.add_from_import("..", "models", ImportType.LOCAL)
file_import.add_from_import("..", "models", ImportType.LOCAL, alias="_models")
return file_import

@property
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/models/parameter_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ def build_flattened_object(self) -> str:
]
)
object_schema = cast(ObjectSchema, self.body[0].schema)
return f"{self.body[0].serialized_name} = models.{object_schema.name}({parameter_string})"
return f"{self.body[0].serialized_name} = _models.{object_schema.name}({parameter_string})"
16 changes: 11 additions & 5 deletions autorest/codegen/serializers/import_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,36 @@
# license information.
# --------------------------------------------------------------------------
from copy import deepcopy
from typing import Dict, Set, Optional, List
from typing import Dict, Set, Optional, List, Tuple, Union
from ..models.imports import ImportType, FileImport, TypingSection

def _serialize_package(package_name: str, module_list: Set[Optional[str]], delimiter: str) -> str:
def _serialize_package(
package_name: str, module_list: Set[Optional[Union[str, Tuple[str, str]]]], delimiter: str
) -> str:
buffer = []
if None in module_list:
buffer.append(f"import {package_name}")
if module_list != {None}:
buffer.append(
"from {} import {}".format(
package_name, ", ".join(sorted([mod for mod in module_list if mod is not None]))
package_name, ", ".join(sorted([
mod if isinstance(mod, str) else f"{mod[0]} as {mod[1]}" for mod in module_list if mod is not None
]))
)
)
return delimiter.join(buffer)

def _serialize_type(import_type_dict: Dict[str, Set[Optional[str]]], delimiter: str) -> str:
def _serialize_type(import_type_dict: Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]], delimiter: str) -> str:
"""Serialize a given import type."""
import_list = []
for package_name in sorted(list(import_type_dict.keys())):
module_list = import_type_dict[package_name]
import_list.append(_serialize_package(package_name, module_list, delimiter))
return delimiter.join(import_list)

def _get_import_clauses(imports: Dict[ImportType, Dict[str, Set[Optional[str]]]], delimiter: str) -> List[str]:
def _get_import_clauses(
imports: Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]], delimiter: str
) -> List[str]:
import_clause = []
for import_type in ImportType:
if import_type in imports:
Expand Down
7 changes: 5 additions & 2 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
import copy
import json
from typing import List, Optional, Set, Tuple, Dict
from typing import List, Optional, Set, Tuple, Dict, Union
from jinja2 import Environment
from ..models import (
CodeModel,
Expand All @@ -26,7 +26,10 @@ def _correct_credential_parameter(global_parameters: ParameterList, async_mode:
credential_param.schema = TokenCredentialSchema(async_mode=async_mode)

def _json_serialize_imports(
imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]
imports: Dict[
TypingSection,
Dict[ImportType, Dict[str, Set[Optional[Union[str, Tuple[str, str]]]]]]
]
):
if not imports:
return None
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/templates/operation_tools.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ error_map = {
{% endif %}
{% for excep in operation.status_code_exceptions %}
{% for status_code in excep.status_codes %}
{% set error_model = ", model=self._deserialize(models." + excep.serialization_type + ", response)" if excep.is_exception else "" %}
{% set error_model = ", model=self._deserialize(_models." + excep.serialization_type + ", response)" if excep.is_exception else "" %}
{% set error_format = ", error_format=ARMErrorFormat" if code_model.options['azure_arm'] else "" %}
{% if status_code == 401 %}
401: lambda response: ClientAuthenticationError(response=response{{ error_model }}{{ error_format }}),
Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/templates/operations_container.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class {{ operation_group.class_name }}{{ object_base_class }}:
"""

{% if code_model.schemas %}
models = models
models = _models

{% endif %}
def __init__(self, client, config, serializer, deserializer){{ return_none_type_annotation }}:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ... import models
from ... import models as _models

T = TypeVar('T')
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]]
Expand All @@ -33,7 +33,7 @@ class DurationOperations:
:param deserializer: An object model deserializer.
"""

models = models
models = _models

def __init__(self, client, config, serializer, deserializer) -> None:
self._client = client
Expand Down Expand Up @@ -76,7 +76,7 @@ async def get_null(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -130,7 +130,7 @@ async def put_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

if cls:
Expand Down Expand Up @@ -173,7 +173,7 @@ async def get_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -219,7 +219,7 @@ async def get_invalid(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from azure.core.pipeline.transport import HttpRequest, HttpResponse
from azure.core.tracing.decorator import distributed_trace

from .. import models
from .. import models as _models

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
Expand All @@ -37,7 +37,7 @@ class DurationOperations(object):
:param deserializer: An object model deserializer.
"""

models = models
models = _models

def __init__(self, client, config, serializer, deserializer):
self._client = client
Expand Down Expand Up @@ -81,7 +81,7 @@ def get_null(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -136,7 +136,7 @@ def put_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

if cls:
Expand Down Expand Up @@ -180,7 +180,7 @@ def get_positive_duration(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down Expand Up @@ -227,7 +227,7 @@ def get_invalid(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.Error, response)
error = self._deserialize(_models.Error, response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize('duration', pipeline_response)
Expand Down
Loading