Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,21 @@
# --------------------------------------------------------------------------
from jinja2 import Environment
from .import_serializer import FileImportSerializer, TypingSection
from ..models import FileImport, ImportType, CodeModel, TokenCredentialSchema
from ..models import FileImport, ImportType, CodeModel, TokenCredentialSchema, ParameterList


def config_imports(code_model, global_parameters: ParameterList, async_mode: bool) -> FileImport:
file_import = FileImport()
file_import.add_from_import("azure.core.configuration", "Configuration", ImportType.AZURECORE)
file_import.add_from_import("azure.core.pipeline", "policies", ImportType.AZURECORE)
file_import.add_from_import("typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL)
if code_model.options["package_version"]:
file_import.add_from_import(".._version" if async_mode else "._version", "VERSION", ImportType.LOCAL)
for gp in global_parameters:
file_import.merge(gp.imports())
if code_model.options["azure_arm"]:
file_import.add_from_import("azure.mgmt.core.policies", "ARMHttpLoggingPolicy", ImportType.AZURECORE)
return file_import


class GeneralSerializer:
Expand Down Expand Up @@ -53,18 +67,6 @@ def _service_client_imports() -> FileImport:
)

def serialize_config_file(self) -> str:
def _config_imports(async_mode: bool) -> FileImport:
file_import = FileImport()
file_import.add_from_import("azure.core.configuration", "Configuration", ImportType.AZURECORE)
file_import.add_from_import("azure.core.pipeline", "policies", ImportType.AZURECORE)
file_import.add_from_import("typing", "Any", ImportType.STDLIB, TypingSection.CONDITIONAL)
if self.code_model.options["package_version"]:
file_import.add_from_import(".._version" if async_mode else "._version", "VERSION", ImportType.LOCAL)
for gp in self.code_model.global_parameters:
file_import.merge(gp.imports())
if self.code_model.options["azure_arm"]:
file_import.add_from_import("azure.mgmt.core.policies", "ARMHttpLoggingPolicy", ImportType.AZURECORE)
return file_import

package_name = self.code_model.options['package_name']
if package_name and package_name.startswith("azure-"):
Expand All @@ -81,7 +83,11 @@ def _config_imports(async_mode: bool) -> FileImport:
return template.render(
code_model=self.code_model,
async_mode=self.async_mode,
imports=FileImportSerializer(_config_imports(self.async_mode), is_python_3_file=self.async_mode),
imports=FileImportSerializer(
config_imports(
self.code_model, self.code_model.global_parameters, self.async_mode
), is_python_3_file=self.async_mode
),
sdk_moniker=sdk_moniker
)

Expand Down
71 changes: 57 additions & 14 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import json
from typing import List, Optional, Set, Tuple, Dict, Union
from jinja2 import Environment
from .general_serializer import config_imports
from ..models import (
CodeModel,
Operation,
FileImport,
OperationGroup,
LROOperation,
PagingOperation,
Expand Down Expand Up @@ -53,6 +54,14 @@ def _json_serialize_imports(
json_serialize_imports[typing_section_key] = json_import_type_dictionary
return json.dumps(json_serialize_imports)

def _mixin_imports(mixin_operation_group: Optional[OperationGroup]) -> Tuple[Optional[str], Optional[str]]:
if not mixin_operation_group:
return None, None

sync_mixin_imports = mixin_operation_group.imports(async_mode=False, has_schemas=False)
async_mixin_imports = mixin_operation_group.imports(async_mode=True, has_schemas=False)

return _json_serialize_imports(sync_mixin_imports.imports), _json_serialize_imports(async_mixin_imports.imports)

class MetadataSerializer:
def __init__(self, code_model: CodeModel, env: Environment) -> None:
Expand Down Expand Up @@ -85,6 +94,32 @@ def _make_async_copy_of_global_parameters(self) -> ParameterList:
_correct_credential_parameter(global_parameters, True)
return global_parameters

def _service_client_imports(
self,
global_parameters: ParameterList,
mixin_operation_group: Optional[OperationGroup],
async_mode: bool
) -> str:
file_import = FileImport()
for gp in global_parameters:
file_import.merge(gp.imports())
file_import.add_from_import("azure.profiles", "KnownProfiles", import_type=ImportType.AZURECORE)
file_import.add_from_import("azure.profiles", "ProfileDefinition", import_type=ImportType.AZURECORE)
file_import.add_from_import(
"azure.profiles.multiapiclient", "MultiApiClientMixin", import_type=ImportType.AZURECORE
)
file_import.add_from_import("._configuration", f"{self.code_model.class_name}Configuration", ImportType.LOCAL)
# api_version and potentially base_url require Optional typing
file_import.add_from_import("typing", "Optional", ImportType.STDLIB, TypingSection.CONDITIONAL)
if mixin_operation_group:
file_import.add_from_import(
"._operations_mixin", f"{self.code_model.class_name}OperationsMixin", ImportType.LOCAL
)

file_import.merge(self.code_model.service_client.imports(self.code_model, async_mode=async_mode))
return _json_serialize_imports(file_import.imports)


def serialize(self) -> str:
def _is_lro(operation):
return isinstance(operation, LROOperation)
Expand All @@ -97,13 +132,9 @@ def _is_paging(operation):
for operation_group in self.code_model.operation_groups if operation_group.is_empty_operation_group),
None
)
mixin_operations: List[Operation] = []
sync_mixin_imports = None
async_mixin_imports = None
if mixin_operation_group:
mixin_operations = mixin_operation_group.operations
sync_mixin_imports = mixin_operation_group.imports(async_mode=False, has_schemas=False)
async_mixin_imports = mixin_operation_group.imports(async_mode=True, has_schemas=False)
mixin_operations = mixin_operation_group.operations if mixin_operation_group else []
sync_mixin_imports, async_mixin_imports = _mixin_imports(mixin_operation_group)

chosen_version, total_api_version_list = self._choose_api_version()

# we separate out async and sync for the case of credentials.
Expand All @@ -119,7 +150,17 @@ def _is_paging(operation):
_correct_credential_parameter(self.code_model.global_parameters, False)
async_global_parameters = self._make_async_copy_of_global_parameters()

sync_client_imports = self._service_client_imports(
self.code_model.global_parameters, mixin_operation_group, async_mode=False
)
async_client_imports = self._service_client_imports(
async_global_parameters, mixin_operation_group, async_mode=True
)

template = self.env.get_template("metadata.json.jinja2")

# setting to true, because for multiapi we always generate with a version file with version 0.1.0
self.code_model.options['package_version'] = '0.1.0'
return template.render(
chosen_version=chosen_version,
total_api_version_list=total_api_version_list,
Expand All @@ -131,12 +172,14 @@ def _is_paging(operation):
is_lro=_is_lro,
is_paging=_is_paging,
str=str,
sync_mixin_imports=(
_json_serialize_imports(sync_mixin_imports.imports)
if sync_mixin_imports else None
sync_mixin_imports=sync_mixin_imports,
async_mixin_imports=async_mixin_imports,
sync_client_imports=sync_client_imports,
async_client_imports=async_client_imports,
sync_config_imports=_json_serialize_imports(
config_imports(self.code_model, self.code_model.global_parameters, async_mode=False).imports
),
async_mixin_imports=(
_json_serialize_imports(async_mixin_imports.imports)
if async_mixin_imports else None
async_config_imports=_json_serialize_imports(
config_imports(self.code_model, async_global_parameters, async_mode=True).imports
)
)
102 changes: 54 additions & 48 deletions autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
"custom_base_url": {{ (keywords.escape_str(code_model.custom_base_url) if code_model.custom_base_url else None) | tojson }},
"azure_arm": {{ code_model.options["azure_arm"] | tojson }},
"has_lro_operations": {{ code_model.has_lro_operations | tojson }},
"client_side_validation": {{ code_model.options["client_side_validation"] | tojson }}
"client_side_validation": {{ code_model.options["client_side_validation"] | tojson }},
"sync_imports": {{ sync_client_imports | tojson }},
"async_imports": {{ async_client_imports | tojson }}
},
"global_parameters": {
"sync": {
Expand All @@ -28,7 +30,7 @@
"async": {
{% for gp in async_global_parameters.method %}
{{ gp.serialized_name | tojson }}: {
"signature": {{ gp.sync_method_signature | tojson }},
"signature": {{ (gp.async_method_signature + ",") | tojson }},
"description": {{ gp.description | tojson }},
"docstring_type": {{ gp.docstring_type | tojson }},
"required": {{ gp.required | tojson }}
Expand All @@ -47,57 +49,61 @@
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }}
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }},
"sync_imports": {{ sync_config_imports | tojson }},
"async_imports": {{ async_config_imports | tojson }}
},
"operation_groups": {
{% for operation_group in code_model.operation_groups | rejectattr('is_empty_operation_group') %}
{{ operation_group.name | tojson }}: {{ operation_group.class_name | tojson }}{{ "," if not loop.last else "" }}
{% endfor %}
},
"operation_mixins": {
{% for operation in mixin_operations %}
{% set operation_name = "begin_" + operation.name if is_lro(operation) else operation.name %}
{{ operation_name | tojson }} : {
"sync": {
{% if is_lro(operation) and is_paging(operation) %}
{% from "lro_paging_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_poller(async_mode=False), operation.get_pager(async_mode=False)] %}
{% elif is_lro(operation) %}
{% from "lro_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_poller(async_mode=False)] %}
{% elif is_paging(operation) %}
{% from "paging_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_pager(async_mode=False)] %}
{% else %}
{% from "operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = "" %}
{% endif %}
"signature": {{ op_tools.method_signature(operation, operation_name, False, False, sync_return_type_wrapper) | tojson }},
"doc": {{ operation_docstring(async_mode=False) | tojson }}
},
"async": {
{% set coroutine = False if is_paging(operation) else True %}
"coroutine": {{ coroutine | tojson }},
{% if is_lro(operation) and is_paging(operation) %}
{% from "lro_paging_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_poller(async_mode=True), operation.get_pager(async_mode=True)] %}
{% elif is_lro(operation) %}
{% from "lro_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_poller(async_mode=True)] %}
{% elif is_paging(operation) %}
{% from "paging_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_pager(async_mode=True)] %}
{% else %}
{% from "operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = "" %}
{% endif %}
"signature": {{ op_tools.method_signature(operation, operation_name, True, coroutine, async_return_type_wrapper) | tojson }},
"doc": {{ operation_docstring(async_mode=True) | tojson }}
},
"call": {{ operation.parameters.method | map(attribute="serialized_name") | join(', ') | tojson }}
}{{ "," if not loop.last else "" }}
{% endfor %}
},
"sync_imports": {{ str(sync_mixin_imports) | tojson }},
"async_imports": {{ str(async_mixin_imports) | tojson }}
"sync_imports": {{ str(sync_mixin_imports) | tojson }},
"async_imports": {{ str(async_mixin_imports) | tojson }},
"operations": {
{% for operation in mixin_operations %}
{% set operation_name = "begin_" + operation.name if is_lro(operation) else operation.name %}
{{ operation_name | tojson }} : {
"sync": {
{% if is_lro(operation) and is_paging(operation) %}
{% from "lro_paging_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_poller(async_mode=False), operation.get_pager(async_mode=False)] %}
{% elif is_lro(operation) %}
{% from "lro_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_poller(async_mode=False)] %}
{% elif is_paging(operation) %}
{% from "paging_operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = [operation.get_pager(async_mode=False)] %}
{% else %}
{% from "operation.py.jinja2" import operation_docstring with context %}
{% set sync_return_type_wrapper = "" %}
{% endif %}
"signature": {{ op_tools.method_signature(operation, operation_name, False, False, sync_return_type_wrapper) | tojson }},
"doc": {{ operation_docstring(async_mode=False) | tojson }}
},
"async": {
{% set coroutine = False if is_paging(operation) else True %}
"coroutine": {{ coroutine | tojson }},
{% if is_lro(operation) and is_paging(operation) %}
{% from "lro_paging_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_poller(async_mode=True), operation.get_pager(async_mode=True)] %}
{% elif is_lro(operation) %}
{% from "lro_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_poller(async_mode=True)] %}
{% elif is_paging(operation) %}
{% from "paging_operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = [operation.get_pager(async_mode=True)] %}
{% else %}
{% from "operation.py.jinja2" import operation_docstring with context %}
{% set async_return_type_wrapper = "" %}
{% endif %}
"signature": {{ op_tools.method_signature(operation, operation_name, True, coroutine, async_return_type_wrapper) | tojson }},
"doc": {{ operation_docstring(async_mode=True) | tojson }}
},
"call": {{ operation.parameters.method | map(attribute="serialized_name") | join(', ') | tojson }}
}{{ "," if not loop.last else "" }}
{% endfor %}
}
}
}
7 changes: 7 additions & 0 deletions autorest/multiapi/models/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
# license information.
# --------------------------------------------------------------------------
import sys
import json
from typing import Any, Dict, List
from pathlib import Path
from .imports import FileImport

def _extract_version(metadata_json: Dict[str, Any], version_path: Path) -> str:
version = metadata_json['chosen_version']
Expand Down Expand Up @@ -34,8 +36,13 @@ def __init__(
self.base_url = default_version_metadata["client"]["base_url"]
self.description = default_version_metadata["client"]["description"]
self.client_side_validation = default_version_metadata["client"]["client_side_validation"]
self.default_version_metadata = default_version_metadata
self.version_path_to_metadata = version_path_to_metadata

def imports(self, async_mode: bool) -> FileImport:
imports_to_load = "async_imports" if async_mode else "sync_imports"
return FileImport(json.loads(self.default_version_metadata['client'][imports_to_load]))

@property
def custom_base_url_to_api_version(self) -> Dict[str, List[str]]:
custom_base_url_to_api_version: Dict[str, List[str]] = {}
Expand Down
4 changes: 3 additions & 1 deletion autorest/multiapi/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def __init__(
self.service_client = Client(self.azure_arm, default_version_metadata, version_path_to_metadata)
self.config = Config(default_version_metadata)
self.operation_mixin_group = OperationMixinGroup(version_path_to_metadata, default_api_version)
self.global_parameters = GlobalParameters(default_version_metadata["global_parameters"])
self.global_parameters = GlobalParameters(
default_version_metadata["global_parameters"], self.service_client.base_url
)
self.user_specified_default_api = user_specified_default_api

@property
Expand Down
7 changes: 7 additions & 0 deletions autorest/multiapi/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import json
from typing import Any, Dict
from .imports import FileImport

class Config:
def __init__(self, default_version_metadata: Dict[str, Any]):
Expand All @@ -14,3 +16,8 @@ def __init__(self, default_version_metadata: Dict[str, Any]):
default_version_metadata["config"]["credential_default_policy_type_has_async_version"]
)
self.credential_key_header_name = default_version_metadata["config"]["credential_key_header_name"]
self.default_version_metadata = default_version_metadata

def imports(self, async_mode: bool) -> FileImport:
imports_to_load = "async_imports" if async_mode else "sync_imports"
return FileImport(json.loads(self.default_version_metadata['config'][imports_to_load]))
Loading