Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 47 additions & 65 deletions autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
import logging
import sys
from typing import Dict, Any, Set, Union, List
from typing import Dict, Any, Set, Union, List, Type
import yaml

from .. import Plugin
Expand All @@ -16,14 +16,8 @@
from .models.parameter_list import GlobalParameterList
from .models.rest import Rest
from .serializers import JinjaSerializer


def _get_credential_default_policy_type_has_async_version(credential_default_policy_type: str) -> bool:
mapping = {
"BearerTokenCredentialPolicy": True,
"AzureKeyCredentialPolicy": False
}
return mapping[credential_default_policy_type]
from .models.credential_schema_policy import CredentialSchemaPolicy, get_credential_schema_policy_type
from .models.credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema

def _build_convenience_layer(yaml_data: Dict[str, Any], code_model: CodeModel) -> None:
# Create operations
Expand Down Expand Up @@ -111,6 +105,8 @@ def _create_code_model(self, yaml_data: Dict[str, Any], options: Dict[str, Union
only_path_and_body_params_positional=only_path_and_body_params_positional,
options=options,
)
if code_model.options['credential']:
self._handle_default_authentication_policy(code_model)
code_model.module_name = yaml_data["info"]["python_title"]
code_model.class_name = yaml_data["info"]["pascal_case_title"]
code_model.description = (
Expand Down Expand Up @@ -176,70 +172,69 @@ def _get_credential_scopes(self, credential):
)
return credential_scopes

def _get_credential_param(self, azure_arm, credential, credential_default_policy_type):
credential_scopes = self._get_credential_scopes(credential)
def _initialize_credential_schema_policy(
self, code_model: CodeModel, credential_schema_policy: Type[CredentialSchemaPolicy]
) -> CredentialSchemaPolicy:
credential_scopes = self._get_credential_scopes(code_model.options['credential'])
credential_key_header_name = self._autorestapi.get_value('credential-key-header-name')
azure_arm = code_model.options['azure_arm']
credential = code_model.options['credential']

if credential_default_policy_type == "BearerTokenCredentialPolicy":
if hasattr(credential_schema_policy, "credential_scopes"):
if not credential_scopes:
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
# If add-credential is specified, we still want to add a credential_scopes variable.
# Will make it an empty list so we can differentiate between this case and None
_LOGGER.warning(
"You have default credential policy BearerTokenCredentialPolicy"
"You have default credential policy %s "
"but not the --credential-scopes flag set while generating non-management plane code. "
"This is not recommend because it forces the customer to pass credential scopes "
"through kwargs if they want to authenticate."
"through kwargs if they want to authenticate.",
credential_schema_policy.name()
)
credential_scopes = []

if credential_key_header_name:
raise ValueError(
"You have passed in a credential key header name with default credential policy type "
"BearerTokenCredentialPolicy. This is not allowed, since credential key header name is tied with "
"AzureKeyCredentialPolicy. Instead, with this policy it is recommend you pass in "
"--credential-scopes."
)
else:
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
"BearerTokenCredentialPolicy. Instead, with this policy you must pass in "
"--credential-key-header-name."
)
if not credential_key_header_name:
credential_key_header_name = "api-key"
_LOGGER.info(
"Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
f"{credential_schema_policy.name()}. This is not allowed, since credential key header "
"name is tied with AzureKeyCredentialPolicy. Instead, with this policy it is recommend you "
"pass in --credential-scopes."
)
return credential_scopes, credential_key_header_name

def _handle_default_authentication_policy(self, azure_arm, credential):

passed_in_credential_default_policy_type = (
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
)

# right now, we only allow BearerTokenCredentialPolicy and AzureKeyCredentialPolicy
allowed_policies = ["BearerTokenCredentialPolicy", "AzureKeyCredentialPolicy"]
try:
credential_default_policy_type = [
cp for cp in allowed_policies if cp.lower() == passed_in_credential_default_policy_type.lower()
][0]
except IndexError:
return credential_schema_policy(
credential=TokenCredentialSchema(async_mode=False),
credential_scopes=credential_scopes,
)
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"The credential you pass in with --credential-default-policy-type must be either "
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
f"{code_model.default_authentication_policy.name()}. Instead, with this policy you must pass in "
"--credential-key-header-name."
)

credential_scopes, credential_key_header_name = self._get_credential_param(
azure_arm, credential, credential_default_policy_type
if not credential_key_header_name:
credential_key_header_name = "api-key"
_LOGGER.info(
"Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
)
return credential_schema_policy(
credential=AzureKeyCredentialSchema(),
credential_key_header_name=credential_key_header_name,
)

return credential_default_policy_type, credential_scopes, credential_key_header_name
def _handle_default_authentication_policy(self, code_model: CodeModel):
credential_schema_policy_name = (
self._autorestapi.get_value("credential-default-policy-type") or
code_model.default_authentication_policy.name()
)
credential_schema_policy_type = get_credential_schema_policy_type(credential_schema_policy_name)
credential_schema_policy = self._initialize_credential_schema_policy(
code_model, credential_schema_policy_type
)
code_model.credential_schema_policy = credential_schema_policy


def _build_code_model_options(self) -> Dict[str, Any]:
Expand All @@ -251,13 +246,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
self._autorestapi.get_boolean_value("add-credential", False)
)

credential_default_policy_type, credential_scopes, credential_key_header_name = (
self._handle_default_authentication_policy(
azure_arm, credential
)
)


license_header = self._autorestapi.get_value("header-text")
if license_header:
license_header = license_header.replace("\n", "\n# ")
Expand All @@ -269,8 +257,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
options: Dict[str, Any] = {
"azure_arm": azure_arm,
"credential": credential,
"credential_scopes": credential_scopes,
"credential_key_header_name": credential_key_header_name,
"head_as_boolean": self._autorestapi.get_boolean_value("head-as-boolean", False),
"license_header": license_header,
"keep_version_file": self._autorestapi.get_boolean_value("keep-version-file", False),
Expand All @@ -282,10 +268,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"client_side_validation": self._autorestapi.get_boolean_value("client-side-validation", False),
"tracing": self._autorestapi.get_boolean_value("trace", False),
"multiapi": self._autorestapi.get_boolean_value("multiapi", False),
"credential_default_policy_type": credential_default_policy_type,
"credential_default_policy_type_has_async_version": (
_get_credential_default_policy_type_has_async_version(credential_default_policy_type)
),
"polymorphic_examples": self._autorestapi.get_value("polymorphic-examples") or 5,
}

Expand Down
28 changes: 20 additions & 8 deletions autorest/codegen/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set, Union
from typing import cast, List, Dict, Optional, Any, Set, Type

from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .credential_schema_policy import (
BearerTokenCredentialPolicy, CredentialSchemaPolicy
)
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(
self.service_client: Client = Client(self, GlobalParameterList())
self._rest: Optional[Rest] = None
self.request_builder_ids: Dict[int, RequestBuilder] = {}
self._credential_schema_policy: Optional[CredentialSchemaPolicy] = None

@property
def global_parameters(self) -> GlobalParameterList:
Expand Down Expand Up @@ -158,14 +161,9 @@ def add_credential_global_parameter(self) -> None:
:return: None
:rtype: None
"""
credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
credential_schema = TokenCredentialSchema(async_mode=False)
else:
credential_schema = AzureKeyCredentialSchema()
credential_parameter = Parameter(
yaml_data={},
schema=credential_schema,
schema=self.credential_schema_policy.credential,
serialized_name="credential",
rest_api_name="credential",
implementation="Client",
Expand Down Expand Up @@ -217,6 +215,20 @@ def _lookup_operation(yaml_id: int) -> Operation:
operation for operation in operation_group.operations if operation not in next_operations
]

@property
def default_authentication_policy(self) -> Type[CredentialSchemaPolicy]:
return BearerTokenCredentialPolicy

@property
def credential_schema_policy(self) -> CredentialSchemaPolicy:
if not self._credential_schema_policy:
raise ValueError("You want to find the Credential Schema Policy, but have not given a value")
return self._credential_schema_policy

@credential_schema_policy.setter
def credential_schema_policy(self, val: CredentialSchemaPolicy) -> None:
self._credential_schema_policy = val

@staticmethod
def _add_properties_from_inheritance_helper(schema, properties) -> List[Property]:
if not schema.base_models:
Expand Down
70 changes: 70 additions & 0 deletions autorest/codegen/models/credential_schema_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from abc import abstractmethod
from typing import List
from .credential_schema import CredentialSchema

class CredentialSchemaPolicy:

def __init__(self, credential: CredentialSchema, *args, **kwargs) -> None: # pylint: disable=unused-argument
self.credential = credential

@abstractmethod
def call(self, async_mode: bool) -> str:
...

@classmethod
def name(cls):
return cls.__name__


class BearerTokenCredentialPolicy(CredentialSchemaPolicy):

def __init__(
self,
credential: CredentialSchema,
credential_scopes: List[str]
) -> None:
super().__init__(credential)
self._credential_scopes = credential_scopes

@property
def credential_scopes(self):
return self._credential_scopes

def call(self, async_mode: bool) -> str:
policy_name = f"Async{self.name()}" if async_mode else self.name()
return f"policies.{policy_name}(self.credential, *self.credential_scopes, **kwargs)"


class AzureKeyCredentialPolicy(CredentialSchemaPolicy):

def __init__(
self,
credential: CredentialSchema,
credential_key_header_name: str
) -> None:
super().__init__(credential)
self._credential_key_header_name = credential_key_header_name

@property
def credential_key_header_name(self):
return self._credential_key_header_name

def call(self, async_mode: bool) -> str:
return f'policies.AzureKeyCredentialPolicy(self.credential, "{self.credential_key_header_name}", **kwargs)'

def get_credential_schema_policy_type(name):
policies = [BearerTokenCredentialPolicy, AzureKeyCredentialPolicy]
try:
return next(p for p in policies if p.name().lower() == name.lower())
except StopIteration:
raise ValueError(
"The credential policy you pass in with --credential-default-policy-type must be either "
"{}".format(
" or ".join([p.name() for p in policies])
)
)
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def serialize_service_client_file(self) -> str:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand All @@ -71,7 +71,7 @@ def serialize_config_file(self) -> str:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def _is_paging(operation):
# for typing purposes.
async_global_parameters = self.code_model.global_parameters
if (
self.code_model.options["credential"]
and self.code_model.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy"
self.code_model.options['credential'] and
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
# this ensures that the TokenCredentialSchema showing up in the list of code model's global parameters
# is sync. This way we only have to make a copy for an async_credential
Expand Down
10 changes: 4 additions & 6 deletions autorest/codegen/templates/config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.{{ constant_parameter.serialized_name }} = {{ constant_parameter.constant_declaration }}
{% endfor %}
{% endif %}
{% if code_model.options['credential_scopes'] is not none %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.options['credential_scopes'] }})
{% if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.credential_schema_policy.credential_scopes }})
{% endif %}
kwargs.setdefault('sdk_moniker', '{{ sdk_moniker }}/{}'.format(VERSION))
self._configure(**kwargs)
Expand All @@ -73,12 +73,10 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.authentication_policy = kwargs.get('authentication_policy')
{% if code_model.options['credential'] %}
{# only adding this if credential_scopes is not passed during code generation #}
{% if code_model.options["credential_scopes"] is not none and code_model.options["credential_scopes"]|length == 0 %}
{% if code_model.credential_schema_policy.credential_scopes is defined and code_model.credential_schema_policy.credential_scopes|length == 0 %}
if not self.credential_scopes and not self.authentication_policy:
raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
{% endif %}
if self.credential and not self.authentication_policy:
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
{% set credential_param_type = ("'" + code_model.options['credential_key_header_name'] + "', ") if code_model.options['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
self.authentication_policy = {{ code_model.credential_schema_policy.call(async_mode) }}
{% endif %}
7 changes: 3 additions & 4 deletions autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@
},
"config": {
"credential": {{ code_model.options['credential'] | tojson }},
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }},
"credential_scopes": {{ (code_model.credential_schema_policy.credential_scopes if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined else None)| tojson}},
"credential_call_sync": {{ (code_model.credential_schema_policy.call(async_mode=False) if code_model.options['credential'] else None) | tojson }},
"credential_call_async": {{ (code_model.credential_schema_policy.call(async_mode=True) if code_model.options['credential'] else None) | tojson }},
"sync_imports": {{ sync_config_imports | tojson }},
"async_imports": {{ async_config_imports | tojson }}
},
Expand Down
Loading