Skip to content

Commit

Permalink
Test generation for Mgmt SDK (#2677)
Browse files Browse the repository at this point in the history
* fix for async paging

* adopt for mpg

* fix

* changelog

* inv

* inv
  • Loading branch information
msyyc committed Jul 3, 2024
1 parent 48d1426 commit 6b9531a
Show file tree
Hide file tree
Showing 34 changed files with 2,397 additions and 75 deletions.
8 changes: 8 additions & 0 deletions .chronus/changes/test-mpg-2024-5-17-15-30-46.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
changeKind: feature
packages:
- "@autorest/python"
- "@azure-tools/typespec-python"
---

Enable test generation for ARM SDK
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def _serialize_namespace_level(self, env: Environment, namespace_path: Path, cli
self.code_model.options["show_operations"]
and self.code_model.has_operations
and self.code_model.options["generate_test"]
and not self.code_model.options["azure_arm"]
):
self._serialize_and_write_test(env, namespace_path)

Expand Down Expand Up @@ -545,13 +544,14 @@ def _serialize_and_write_test(self, env: Environment, namespace_path: Path):
out_path = self._package_root_folder(namespace_path) / Path("generated_tests")
general_serializer = TestGeneralSerializer(code_model=self.code_model, env=env)
self.write_file(out_path / "conftest.py", general_serializer.serialize_conftest())
for is_async in (True, False):
async_suffix = "_async" if is_async else ""
general_serializer.is_async = is_async
self.write_file(
out_path / f"testpreparer{async_suffix}.py",
general_serializer.serialize_testpreparer(),
)
if not self.code_model.options["azure_arm"]:
for is_async in (True, False):
async_suffix = "_async" if is_async else ""
general_serializer.is_async = is_async
self.write_file(
out_path / f"testpreparer{async_suffix}.py",
general_serializer.serialize_testpreparer(),
)

for client in self.code_model.clients:
for og in client.operation_groups:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,26 @@
ModelType,
BaseType,
CombinedType,
FileImport,
)
from .utils import get_namespace_from_package_name, json_dumps_template


def is_lro(operation_type: str) -> bool:
return operation_type in ("lro", "lropaging")


def is_paging(operation_type: str) -> bool:
return operation_type in ("paging", "lropaging")


def is_common_operation(operation_type: str) -> bool:
return operation_type == "operation"


class TestName:
def __init__(self, client_name: str, *, is_async: bool = False) -> None:
def __init__(self, code_model: CodeModel, client_name: str, *, is_async: bool = False) -> None:
self.code_model = code_model
self.client_name = client_name
self.is_async = is_async

Expand All @@ -41,10 +55,14 @@ def prefix(self) -> str:

@property
def preparer_name(self) -> str:
if self.code_model.options["azure_arm"]:
return "RandomNameResourceGroupPreparer"
return self.prefix + "Preparer"

@property
def base_test_class_name(self) -> str:
if self.code_model.options["azure_arm"]:
return "AzureMgmtRecordedTestCase"
return f"{self.client_name}TestBase{self.async_suffix_capt}"


Expand All @@ -71,48 +89,43 @@ def operation_group_prefix(self) -> str:
@property
def response(self) -> str:
if self.is_async:
if self.operation.operation_type == "lropaging":
if is_lro(self.operation.operation_type):
return "response = await (await "
return "response = await "
if is_common_operation(self.operation.operation_type):
return "response = await "
return "response = "

@property
def lro_comment(self) -> str:
return " # poll until service return final result"
return " # call '.result()' to poll until service return final result"

@property
def operation_suffix(self) -> str:
if self.operation.operation_type == "lropaging":
if is_lro(self.operation.operation_type):
extra = ")" if self.is_async else ""
return f"{extra}.result(){self.lro_comment}"
return ""

@property
def extra_operation(self) -> str:
if self.is_async:
if self.operation.operation_type == "lro":
return f"result = await response.result(){self.lro_comment}"
if self.operation.operation_type == ("lropaging", "paging"):
return "result = [r async for r in response]"
else:
if self.operation.operation_type == "lro":
return f"result = response.result(){self.lro_comment}"
if self.operation.operation_type in ("lropaging", "paging"):
return "result = [r for r in response]"
if is_paging(self.operation.operation_type):
async_str = "async " if self.is_async else ""
return f"result = [r {async_str}for r in response]"
return ""


class Test(TestName):
def __init__(
self,
code_model: CodeModel,
client_name: str,
operation_group: OperationGroup,
testcases: List[TestCase],
test_class_name: str,
*,
is_async: bool = False,
) -> None:
super().__init__(client_name, is_async=is_async)
super().__init__(code_model, client_name, is_async=is_async)
self.operation_group = operation_group
self.testcases = testcases
self.test_class_name = test_class_name
Expand All @@ -129,19 +142,23 @@ def aio_str(self) -> str:

@property
def test_names(self) -> List[TestName]:
return [TestName(c.name, is_async=self.is_async) for c in self.code_model.clients]
return [TestName(self.code_model, c.name, is_async=self.is_async) for c in self.code_model.clients]

def add_import_client(self, imports: FileImport) -> None:
namespace = get_namespace_from_package_name(self.code_model.options["package_name"])
for client in self.code_model.clients:
imports.add_submodule_import(namespace + self.aio_str, client.name, ImportType.STDLIB)

@property
def import_clients(self) -> FileImportSerializer:
imports = self.init_file_import()
namespace = get_namespace_from_package_name(self.code_model.options["package_name"])

imports.add_submodule_import("devtools_testutils", "AzureRecordedTestCase", ImportType.STDLIB)
if not self.is_async:
imports.add_import("functools", ImportType.STDLIB)
imports.add_submodule_import("devtools_testutils", "PowerShellPreparer", ImportType.STDLIB)
for client in self.code_model.clients:
imports.add_submodule_import(namespace + self.aio_str, client.name, ImportType.STDLIB)
self.add_import_client(imports)

return FileImportSerializer(imports, self.is_async)

def serialize_conftest(self) -> str:
Expand Down Expand Up @@ -175,19 +192,25 @@ def __init__(
@property
def import_test(self) -> FileImportSerializer:
imports = self.init_file_import()
test_name = TestName(self.client.name, is_async=self.is_async)
test_name = TestName(self.code_model, self.client.name, is_async=self.is_async)
async_suffix = "_async" if self.is_async else ""
imports.add_submodule_import(
"testpreparer" + async_suffix,
"devtools_testutils" if self.code_model.options["azure_arm"] else "testpreparer" + async_suffix,
test_name.base_test_class_name,
ImportType.LOCAL,
)
imports.add_submodule_import("testpreparer", test_name.preparer_name, ImportType.LOCAL)
imports.add_submodule_import(
"devtools_testutils" if self.code_model.options["azure_arm"] else "testpreparer",
test_name.preparer_name,
ImportType.LOCAL,
)
imports.add_submodule_import(
"devtools_testutils" + self.aio_str,
"recorded_by_proxy" + async_suffix,
ImportType.LOCAL,
)
if self.code_model.options["azure_arm"]:
self.add_import_client(imports)
return FileImportSerializer(imports, self.is_async)

@property
Expand Down Expand Up @@ -242,6 +265,7 @@ def get_test(self) -> Test:
raise Exception("no public operation to test") # pylint: disable=broad-exception-raised

return Test(
code_model=self.code_model,
client_name=self.client.name,
operation_group=self.operation_group,
testcases=testcases,
Expand All @@ -251,7 +275,7 @@ def get_test(self) -> Test:

@property
def test_class_name(self) -> str:
test_name = TestName(self.client.name, is_async=self.is_async)
test_name = TestName(self.code_model, self.client.name, is_async=self.is_async)
class_name = "" if self.operation_group.is_mixin else self.operation_group.class_name
return f"Test{test_name.prefix}{class_name}{test_name.async_suffix_capt}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ load_dotenv()
@pytest.fixture(scope="session", autouse=True)
def add_sanitizers(test_proxy):
{% for test_name in test_names %}
{% set prefix_upper = test_name.prefix|upper %}
{% set prefix_upper = "AZURE" if code_model.options["azure_arm"] else test_name.prefix|upper %}
{% set prefix_lower = test_name.prefix|lower %}
{{ prefix_lower }}_subscription_id = os.environ.get("{{ prefix_upper }}_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000")
{{ prefix_lower }}_tenant_id = os.environ.get("{{ prefix_upper }}_TENANT_ID", "00000000-0000-0000-0000-000000000000")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,46 @@
{% set prefix_lower = test.prefix|lower %}
{% set client_var = "self.client" if code_model.options["azure_arm"] else "client" %}
{% set async = "async " if test.is_async else "" %}
{% set async_suffix = "_async" if test.is_async else "" %}
# coding=utf-8
{{ code_model.options['license_header'] }}
import pytest
{{ imports }}

{% if code_model.options["azure_arm"] %}
AZURE_LOCATION = "eastus"
{% endif %}

@pytest.mark.skip("you may need to update the auto-generated test case before run it")
class {{ test.test_class_name }}({{ test.base_test_class_name }}):
{% if code_model.options["azure_arm"] %}
def setup_method(self, method):
{% if test.is_async %}
self.client = self.create_mgmt_client({{ test.client_name }}, is_async=True)
{% else %}
self.client = self.create_mgmt_client({{ test.client_name }})
{% endif %}
{% endif %}
{% for testcase in test.testcases %}
{% if code_model.options["azure_arm"] %}
@{{ test.preparer_name }}(location=AZURE_LOCATION)
{% else %}
@{{ test.preparer_name }}()
{% endif %}
@recorded_by_proxy{{ async_suffix }}
{% if code_model.options["azure_arm"] %}
{{ async }}def test_{{ testcase.operation.name }}(self, resource_group):
{% else %}
{{ async }}def test_{{ testcase.operation.name }}(self, {{ prefix_lower }}_endpoint):
client = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
{{testcase.response }}client{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
{{ client_var }} = self.{{ test.create_client_name }}(endpoint={{ prefix_lower }}_endpoint)
{% endif %}
{{testcase.response }}{{ client_var }}{{ testcase.operation_group_prefix }}.{{ testcase.operation.name }}(
{% for key, value in testcase.params.items() %}
{% if code_model.options["azure_arm"] and key == "resource_group_name" %}
{{ key }}=resource_group.name,
{% else %}
{{ key }}={{ value|indent(12) }},
{% endif %}
{% endfor %}
){{ testcase.operation_suffix }}
{{ testcase.extra_operation }}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# Code generated by Microsoft (R) Python Code Generator.
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
import os
import pytest
from dotenv import load_dotenv
from devtools_testutils import (
test_proxy,
add_general_regex_sanitizer,
add_body_key_sanitizer,
add_header_regex_sanitizer,
)

load_dotenv()


# aovid record sensitive identity information in recordings
@pytest.fixture(scope="session", autouse=True)
def add_sanitizers(test_proxy):
resources_subscription_id = os.environ.get("AZURE_SUBSCRIPTION_ID", "00000000-0000-0000-0000-000000000000")
resources_tenant_id = os.environ.get("AZURE_TENANT_ID", "00000000-0000-0000-0000-000000000000")
resources_client_id = os.environ.get("AZURE_CLIENT_ID", "00000000-0000-0000-0000-000000000000")
resources_client_secret = os.environ.get("AZURE_CLIENT_SECRET", "00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(regex=resources_subscription_id, value="00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(regex=resources_tenant_id, value="00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(regex=resources_client_id, value="00000000-0000-0000-0000-000000000000")
add_general_regex_sanitizer(regex=resources_client_secret, value="00000000-0000-0000-0000-000000000000")

add_header_regex_sanitizer(key="Set-Cookie", value="[set-cookie;]")
add_header_regex_sanitizer(key="Cookie", value="cookie;")
add_body_key_sanitizer(json_path="$..access_token", value="access_token")
Loading

0 comments on commit 6b9531a

Please sign in to comment.