Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions autorest/codegen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
# --------------------------------------------------------------------------
from typing import Any, Dict
from .base_model import BaseModel
from .code_model import CodeModel, CredentialSchema
from .code_model import CodeModel
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .object_schema import ObjectSchema
from .dictionary_schema import DictionarySchema
from .list_schema import ListSchema
Expand All @@ -25,10 +26,10 @@


__all__ = [
"AzureKeyCredentialSchema",
"BaseModel",
"BaseSchema",
"CodeModel",
"CredentialSchema",
"ConstantSchema",
"ObjectSchema",
"DictionarySchema",
Expand All @@ -45,7 +46,8 @@
"ParameterList",
"OperationGroup",
"Property",
"SchemaResponse"
"SchemaResponse",
"TokenCredentialSchema",
]

def _generate_as_object_schema(yaml_data: Dict[str, Any]) -> bool:
Expand Down
54 changes: 7 additions & 47 deletions autorest/codegen/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set
from typing import cast, List, Dict, Optional, Any, Set, Union

from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup
Expand All @@ -17,7 +18,6 @@
from .parameter import Parameter, ParameterLocation
from .client import Client
from .parameter_list import ParameterList
from .imports import FileImport, ImportType, TypingSection
from .schema_response import SchemaResponse
from .property import Property
from .primitive_schemas import IOSchema
Expand All @@ -26,50 +26,6 @@
_LOGGER = logging.getLogger(__name__)


class CredentialSchema(BaseSchema):
def __init__(self, async_mode) -> None: # pylint: disable=super-init-not-called
self.async_mode = async_mode
self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
self.sync_type = "~azure.core.credentials.TokenCredential"
self.default_value = None

@property
def serialization_type(self) -> str:
if self.async_mode:
return self.async_type
return self.sync_type

@property
def docstring_type(self) -> str:
return self.serialization_type

@property
def type_annotation(self) -> str:
if self.async_mode:
return '"AsyncTokenCredential"'
return '"TokenCredential"'

@property
def docstring_text(self) -> str:
return "credential"

def imports(self) -> FileImport:
file_import = FileImport()
if self.async_mode:
file_import.add_from_import(
"azure.core.credentials_async", "AsyncTokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
else:
file_import.add_from_import(
"azure.core.credentials", "TokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
return file_import


class CodeModel: # pylint: disable=too-many-instance-attributes
"""Holds all of the information we have parsed out of the yaml file. The CodeModel is what gets
serialized by the serializers.
Expand Down Expand Up @@ -168,7 +124,11 @@ def add_credential_global_parameter(self) -> None:
:return: None
:rtype: None
"""
credential_schema = CredentialSchema(async_mode=False)
credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
credential_schema = TokenCredentialSchema(async_mode=False)
else:
credential_schema = AzureKeyCredentialSchema()
credential_parameter = Parameter(
yaml_data={},
schema=credential_schema,
Expand Down
83 changes: 83 additions & 0 deletions autorest/codegen/models/credential_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from .base_schema import BaseSchema
from .imports import FileImport, ImportType, TypingSection

class CredentialSchema(BaseSchema):
def __init__(self) -> None: # pylint: disable=super-init-not-called
self.default_value = None

@property
def docstring_type(self) -> str:
return self.serialization_type

@property
def docstring_text(self) -> str:
return "credential"

@property
def serialization_type(self) -> str:
# this property is added, because otherwise pylint says that
# abstract serialization_type in BaseSchema is not overridden
pass


class AzureKeyCredentialSchema(CredentialSchema):

@property
def serialization_type(self) -> str:
return "~azure.core.credentials.AzureKeyCredential"

@property
def type_annotation(self) -> str:
return "AzureKeyCredential"

def imports(self) -> FileImport:
file_import = FileImport()
file_import.add_from_import(
"azure.core.credentials",
"AzureKeyCredential",
ImportType.AZURECORE,
typing_section=TypingSection.CONDITIONAL
)
return file_import


class TokenCredentialSchema(CredentialSchema):
def __init__(self, async_mode) -> None:
super(TokenCredentialSchema, self).__init__()
self.async_mode = async_mode
self.async_type = "~azure.core.credentials_async.AsyncTokenCredential"
self.sync_type = "~azure.core.credentials.TokenCredential"

@property
def serialization_type(self) -> str:
if self.async_mode:
return self.async_type
return self.sync_type

@property
def type_annotation(self) -> str:
if self.async_mode:
return '"AsyncTokenCredential"'
return '"TokenCredential"'


def imports(self) -> FileImport:
file_import = FileImport()
if self.async_mode:
file_import.add_from_import(
"azure.core.credentials_async", "AsyncTokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
else:
file_import.add_from_import(
"azure.core.credentials", "TokenCredential",
ImportType.AZURECORE,
typing_section=TypingSection.TYPING
)
return file_import
16 changes: 11 additions & 5 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
from jinja2 import Environment
from .import_serializer import FileImportSerializer, TypingSection
from ..models import FileImport, ImportType, CodeModel, CredentialSchema
from ..models import FileImport, ImportType, CodeModel, TokenCredentialSchema


class GeneralSerializer:
Expand All @@ -24,9 +24,9 @@ def serialize_init_file(self) -> str:

def _correct_credential_parameter(self):
credential_param = [
gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
gp for gp in self.code_model.global_parameters.parameters if isinstance(gp.schema, TokenCredentialSchema)
][0]
credential_param.schema = CredentialSchema(async_mode=self.async_mode)
credential_param.schema = TokenCredentialSchema(async_mode=self.async_mode)

def serialize_service_client_file(self) -> str:
def _service_client_imports() -> FileImport:
Expand All @@ -37,7 +37,10 @@ def _service_client_imports() -> FileImport:

template = self.env.get_template("service_client.py.jinja2")

if self.code_model.options['credential']:
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
):
self._correct_credential_parameter()

return template.render(
Expand Down Expand Up @@ -68,7 +71,10 @@ def _config_imports(async_mode: bool) -> FileImport:
package_name = package_name[len("azure-"):]
sdk_moniker = package_name if package_name else self.code_model.class_name.lower()

if self.code_model.options['credential']:
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
):
self._correct_credential_parameter()

template = self.env.get_template("config.py.jinja2")
Expand Down
13 changes: 8 additions & 5 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
OperationGroup,
LROOperation,
PagingOperation,
CredentialSchema,
TokenCredentialSchema,
ParameterList,
TypingSection,
ImportType
)

def _correct_credential_parameter(global_parameters: ParameterList, async_mode: bool) -> None:
credential_param = [
gp for gp in global_parameters.parameters if isinstance(gp.schema, CredentialSchema)
gp for gp in global_parameters.parameters if isinstance(gp.schema, TokenCredentialSchema)
][0]
credential_param.schema = CredentialSchema(async_mode=async_mode)
credential_param.schema = TokenCredentialSchema(async_mode=async_mode)

def _json_serialize_imports(
imports: Dict[TypingSection, Dict[ImportType, Dict[str, Set[Optional[str]]]]]
Expand Down Expand Up @@ -107,8 +107,11 @@ def _is_paging(operation):
# In this case, we need two copies of the credential global parameter
# for typing purposes.
async_global_parameters = self.code_model.global_parameters
if self.code_model.options['credential']:
# this ensures that the CredentialSchema showing up in the list of code model's global parameters
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
):
# this ensures that the TokenCredentialSchema showing up in the list of code model's global parameters
# is sync. This way we only have to make a copy for an async_credential
_correct_credential_parameter(self.code_model.global_parameters, False)
async_global_parameters = self._make_async_copy_of_global_parameters()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Dict, Optional

from azure.core.credentials import TokenCredential
from azure.core.credentials import AzureKeyCredential

from ._configuration import AutoRestHeadTestServiceConfiguration
from .operations import HttpSuccessOperations
Expand All @@ -27,13 +27,13 @@ class AutoRestHeadTestService(object):
:ivar http_success: HttpSuccessOperations operations
:vartype http_success: headwithazurekeycredentialpolicy.operations.HttpSuccessOperations
:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.AzureKeyCredential
:param str base_url: Service URL
"""

def __init__(
self,
credential, # type: "TokenCredential"
credential, # type: AzureKeyCredential
base_url=None, # type: Optional[str]
**kwargs # type: Any
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# pylint: disable=unused-import,ungrouped-imports
from typing import Any

from azure.core.credentials import TokenCredential
from azure.core.credentials import AzureKeyCredential


class AutoRestHeadTestServiceConfiguration(Configuration):
Expand All @@ -28,12 +28,12 @@ class AutoRestHeadTestServiceConfiguration(Configuration):
attributes.

:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials.TokenCredential
:type credential: ~azure.core.credentials.AzureKeyCredential
"""

def __init__(
self,
credential, # type: "TokenCredential"
credential, # type: AzureKeyCredential
**kwargs # type: Any
):
# type: (...) -> None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

from typing import Any, Optional, TYPE_CHECKING

from azure.core.credentials import AzureKeyCredential
from azure.mgmt.core import AsyncARMPipelineClient
from msrest import Deserializer, Serializer

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Dict

from azure.core.credentials_async import AsyncTokenCredential

from ._configuration_async import AutoRestHeadTestServiceConfiguration
from .operations_async import HttpSuccessOperations

Expand All @@ -27,13 +26,13 @@ class AutoRestHeadTestService(object):
:ivar http_success: HttpSuccessOperations operations
:vartype http_success: headwithazurekeycredentialpolicy.aio.operations_async.HttpSuccessOperations
:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:type credential: ~azure.core.credentials.AzureKeyCredential
:param str base_url: Service URL
"""

def __init__(
self,
credential: "AsyncTokenCredential",
credential: AzureKeyCredential,
base_url: Optional[str] = None,
**kwargs: Any
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,15 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import Any, TYPE_CHECKING
from typing import Any

from azure.core.configuration import Configuration
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline import policies
from azure.mgmt.core.policies import ARMHttpLoggingPolicy

from .._version import VERSION

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from azure.core.credentials_async import AsyncTokenCredential


class AutoRestHeadTestServiceConfiguration(Configuration):
"""Configuration for AutoRestHeadTestService.
Expand All @@ -26,12 +23,12 @@ class AutoRestHeadTestServiceConfiguration(Configuration):
attributes.

:param credential: Credential needed for the client to connect to Azure.
:type credential: ~azure.core.credentials_async.AsyncTokenCredential
:type credential: ~azure.core.credentials.AzureKeyCredential
"""

def __init__(
self,
credential: "AsyncTokenCredential",
credential: AzureKeyCredential,
**kwargs: Any
) -> None:
if credential is None:
Expand Down
Loading