Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/autorest.python/ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

- Fix `x-ms-client-default` for model property #1937
- Added sub-namespace folder when generating samples #1920
- Optimize logic to find realted params in example files #1916
- Optimize logic to find related params in example files #1916
- Optimize default value for `api_version` for better compatibility of multiapi package #1934

### 2023-05-19 - 6.4.15

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def pop_kwargs_from_signature(self, builder: OperationType) -> List[str]:
)
else PopKwargType.SIMPLE,
check_client_input=not self.code_model.options["multiapi"],
in_operation=True,
)
cls_annotation = builder.cls_type_annotation(async_mode=self.async_mode)
pylint_disable = ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,16 @@ def _get_client_models_value(models_dict_name: str) -> str:
og for og in self.client.operation_groups if not og.is_mixin
]
for og in operation_groups:
if og.code_model.options["multiapi"]:
api_version = (
f", '{og.api_versions[0]}'" if og.api_versions else ", None"
)
else:
api_version = ""
retval.extend(
[
f"self.{og.property_name} = {og.class_name}({og.pylint_disable}",
" self._client, self._config, self._serialize, self._deserialize",
f" self._client, self._config, self._serialize, self._deserialize{api_version}",
")",
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def pop_kwargs_from_signature(
pop_headers_kwarg: PopKwargType,
pop_params_kwarg: PopKwargType,
check_client_input: bool = False,
in_operation: bool = False,
) -> List[str]:
retval = []

Expand Down Expand Up @@ -142,9 +143,16 @@ def append_pop_kwarg(key: str, pop_type: PopKwargType) -> None:
if kwarg.location == ParameterLocation.HEADER
else "params"
)
if (
kwarg.client_name == "api_version"
and kwarg.code_model.options["multiapi"]
and in_operation
):
default_value = f"self._api_version or {default_value}"
default_value = (
f"_{kwarg_dict}.pop('{kwarg.wire_name}', {default_value})"
)

retval.append(
f"{kwarg.client_name}: {type_annot} = kwargs.pop('{kwarg.client_name}', "
+ f"{default_value})"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,20 @@ class {{ operation_group.class_name }}{{ base_class }}:{{ disable }}
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
{% if code_model.options["multiapi"] %}
self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")
{% endif %}
{{ check_abstract_methods() }}
{% elif operation_group.has_abstract_operations %}

def __init__(self){{ return_none_type_annotation }}:
{{ check_abstract_methods() }}
{% endif %}
{% if operation_group.is_mixin and code_model.options["multiapi"] %}
@property
def _api_version(self) -> str:
return self._config.api_version
{% endif %}
{% for operation in operation_group.operations if not operation.abstract %}

{% set request_builder = operation.request_builder %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class {{ code_model.client.name }}({% if code_model.operation_mixin_group.mixin_
else:
raise ValueError("API version {} does not have operation group '{{ operation_group.name }}'".format(api_version))
self._config.api_version = api_version
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)))
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)
{% endfor %}

{{ def }} close(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def operation_group_one(self):
else:
raise ValueError("API version {} does not have operation group 'operation_group_one'".format(api_version))
self._config.api_version = api_version
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)))
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

@property
def operation_group_two(self):
Expand All @@ -144,7 +144,7 @@ def operation_group_two(self):
else:
raise ValueError("API version {} does not have operation group 'operation_group_two'".format(api_version))
self._config.api_version = api_version
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)))
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

def close(self):
self._client.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def operation_group_one(self):
else:
raise ValueError("API version {} does not have operation group 'operation_group_one'".format(api_version))
self._config.api_version = api_version
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)))
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

@property
def operation_group_two(self):
Expand All @@ -144,7 +144,7 @@ def operation_group_two(self):
else:
raise ValueError("API version {} does not have operation group 'operation_group_two'".format(api_version))
self._config.api_version = api_version
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)))
return OperationClass(self._client, self._config, Serializer(self._models_dict(api_version)), Deserializer(self._models_dict(api_version)), api_version)

async def close(self):
await self._client.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, credential: "TokenCredential", base_url: str = "http://localh
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "0.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> HttpResponse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "0.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> Awaitable[AsyncHttpResponse]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, *args, **kwargs) -> None:
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")

@distributed_trace_async
async def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
Expand All @@ -70,7 +71,7 @@ async def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "0.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "0.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_two_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, *args, **kwargs):
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")

@distributed_trace
def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
Expand All @@ -92,7 +93,7 @@ def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-retur
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "0.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "0.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_two_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, credential: "TokenCredential", base_url: str = "http://localh
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "1.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> HttpResponse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "1.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> Awaitable[AsyncHttpResponse]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@


class MultiapiServiceClientOperationsMixin(MultiapiServiceClientMixinABC):
@property
def _api_version(self) -> str:
return self._config.api_version

@distributed_trace_async
async def test_one( # pylint: disable=inconsistent-return-statements
self, id: int, message: Optional[str] = None, **kwargs: Any
Expand All @@ -70,7 +74,7 @@ async def test_one( # pylint: disable=inconsistent-return-statements
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_one_request(
Expand Down Expand Up @@ -498,7 +502,7 @@ async def test_different_calls( # pylint: disable=inconsistent-return-statement
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_different_calls_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self, *args, **kwargs) -> None:
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")

@distributed_trace_async
async def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
Expand All @@ -71,7 +72,7 @@ async def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_two_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def build_test_different_calls_request(*, greeting_in_english: str, **kwargs: An


class MultiapiServiceClientOperationsMixin(MultiapiServiceClientMixinABC):
@property
def _api_version(self) -> str:
return self._config.api_version

@distributed_trace
def test_one( # pylint: disable=inconsistent-return-statements
self, id: int, message: Optional[str] = None, **kwargs: Any
Expand All @@ -147,7 +151,7 @@ def test_one( # pylint: disable=inconsistent-return-statements
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_one_request(
Expand Down Expand Up @@ -572,7 +576,7 @@ def test_different_calls( # pylint: disable=inconsistent-return-statements
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_different_calls_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, *args, **kwargs):
self._config = input_args.pop(0) if input_args else kwargs.pop("config")
self._serialize = input_args.pop(0) if input_args else kwargs.pop("serializer")
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")
self._api_version = input_args.pop(0) if input_args else kwargs.pop("api_version")

@distributed_trace
def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
Expand All @@ -92,7 +93,7 @@ def test_two(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-retur
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "1.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "1.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_two_request(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(self, credential: "TokenCredential", base_url: str = "http://localh
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "2.0.0"
)
self.operation_group_two = OperationGroupTwoOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "2.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> HttpResponse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def __init__(
self._deserialize = Deserializer(client_models)
self._serialize.client_side_validation = False
self.operation_group_one = OperationGroupOneOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "2.0.0"
)
self.operation_group_two = OperationGroupTwoOperations(
self._client, self._config, self._serialize, self._deserialize
self._client, self._config, self._serialize, self._deserialize, "2.0.0"
)

def _send_request(self, request: HttpRequest, **kwargs: Any) -> Awaitable[AsyncHttpResponse]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@


class MultiapiServiceClientOperationsMixin(MultiapiServiceClientMixinABC):
@property
def _api_version(self) -> str:
return self._config.api_version

@distributed_trace_async
async def test_one(self, id: int, message: Optional[str] = None, **kwargs: Any) -> _models.ModelTwo:
"""TestOne should be in an SecondVersionOperationsMixin. Returns ModelTwo.
Expand All @@ -57,7 +61,7 @@ async def test_one(self, id: int, message: Optional[str] = None, **kwargs: Any)
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "2.0.0"))
cls: ClsType[_models.ModelTwo] = kwargs.pop("cls", None)

request = build_test_one_request(
Expand Down Expand Up @@ -118,7 +122,7 @@ async def test_different_calls( # pylint: disable=inconsistent-return-statement
_headers = kwargs.pop("headers", {}) or {}
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", "2.0.0"))
api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._api_version or "2.0.0"))
cls: ClsType[None] = kwargs.pop("cls", None)

request = build_test_different_calls_request(
Expand Down
Loading