Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ Modelerfour version: 4.15.378
**New Features**

- Add support for `x-ms-text` XML extension #722
- Allow users to pass the name of the key header for `AzureKeyCredentialPolicy` during generation. To use, pass in
`AzureKeyCredentialPolicy` with the `--credential-default-policy-type` flag, and pass in the key header name using
the `--credential-key-header-name` flag #736

**Bug Fixes**

Expand Down
95 changes: 66 additions & 29 deletions autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,7 @@ def _create_code_model(self, yaml_data: Dict[str, Any], options: Dict[str, Union

return code_model

def _build_code_model_options(self) -> Dict[str, Any]:
"""Build en options dict from the user input while running autorest.
"""
azure_arm = self._autorestapi.get_boolean_value("azure-arm", False)
credential = (
self._autorestapi.get_boolean_value("add-credentials", False) or
self._autorestapi.get_boolean_value("add-credential", False)
)

def _get_credential_scopes(self, credential):
credential_scopes_temp = self._autorestapi.get_value("credential-scopes")
credential_scopes = credential_scopes_temp.split(",") if credential_scopes_temp else None
if credential_scopes and not credential:
Expand All @@ -142,6 +134,50 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"--credential-scopes takes a list of scopes in comma separated format. "
"For example: --credential-scopes=https://cognitiveservices.azure.com/.default"
)
return credential_scopes

def _get_credential_param(self, azure_arm, credential, credential_default_policy_type):
credential_scopes = self._get_credential_scopes(credential)
credential_key_header_name = self._autorestapi.get_value('credential-key-header-name')

if credential_default_policy_type == "BearerTokenCredentialPolicy":
if not credential_scopes:
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
# If add-credential is specified, we still want to add a credential_scopes variable.
# Will make it an empty list so we can differentiate between this case and None
_LOGGER.warning(
"You have used the --add-credential flag but not the --credential-scopes flag "
"while generating non-management plane code. "
"This is not recommend because it forces the customer to pass credential scopes "
"through kwargs if they want to authenticate."
)
credential_scopes = []
if credential_key_header_name:
raise ValueError(
"You have passed in a credential key header name with default credential policy type "
"BearerTokenCredentialPolicy. This is not allowed, since credential key header name is tied with "
"AzureKeyCredentialPolicy. Instead, with this policy it is recommend you pass in "
"--credential-scopes."
)
else:
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
"BearerTokenCredentialPolicy. Instead, with this policy you must pass in "
"--credential-key-header-name."
)
if not credential_key_header_name:
raise ValueError(
"With default credential policy type AzureKeyCredentialPolicy, you must pass in the name "
"of the key header with the flag --credential-key-header-name"
)
return credential_scopes, credential_key_header_name

def _handle_default_authentication_policy(self, azure_arm, credential):

passed_in_credential_default_policy_type = (
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
Expand All @@ -159,27 +195,27 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
)

if credential_scopes and credential_default_policy_type != "BearerTokenCredentialPolicy":
_LOGGER.warning(
"You have --credential-default-policy-type not set as BearerTokenCredentialPolicy and a value for "
"--credential-scopes. Since credential scopes are tied to the BearerTokenCredentialPolicy, "
"we will ignore your credential scopes."
credential_scopes, credential_key_header_name = self._get_credential_param(
azure_arm, credential, credential_default_policy_type
)

return credential_default_policy_type, credential_scopes, credential_key_header_name


def _build_code_model_options(self) -> Dict[str, Any]:
"""Build en options dict from the user input while running autorest.
"""
azure_arm = self._autorestapi.get_boolean_value("azure-arm", False)
credential = (
self._autorestapi.get_boolean_value("add-credentials", False) or
self._autorestapi.get_boolean_value("add-credential", False)
)

credential_default_policy_type, credential_scopes, credential_key_header_name = (
self._handle_default_authentication_policy(
azure_arm, credential
)
credential_scopes = []

elif not credential_scopes and credential_default_policy_type == "BearerTokenCredentialPolicy":
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
# If add-credential is specified, we still want to add a credential_scopes variable.
# Will make it an empty list so we can differentiate between this case and None
_LOGGER.warning(
"You have used the --add-credential flag but not the --credential-scopes flag "
"while generating non-management plane code. "
"This is not recommend because it forces the customer to pass credential scopes "
"through kwargs if they want to authenticate."
)
credential_scopes = []
)


license_header = self._autorestapi.get_value("header-text")
Expand All @@ -194,6 +230,7 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"azure_arm": azure_arm,
"credential": credential,
"credential_scopes": credential_scopes,
"credential_key_header_name": credential_key_header_name,
"head_as_boolean": self._autorestapi.get_boolean_value("head-as-boolean", False),
"license_header": license_header,
"keep_version_file": self._autorestapi.get_boolean_value("keep-version-file", False),
Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/templates/config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ class {{ code_model.class_name }}Configuration(Configuration):
{% endif %}
if self.credential and not self.authentication_policy:
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
{% set credential_param_type = ("'" + code_model.options['credential_key_header_name'] + "', ") if code_model.options['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
{% endif %}
3 changes: 2 additions & 1 deletion autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
"credential": {{ code_model.options['credential'] | tojson }},
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }}
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }}
},
"operation_groups": {
{% for operation_group in code_model.operation_groups %}
Expand Down
4 changes: 2 additions & 2 deletions autorest/multiapi/templates/multiapi_config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,6 @@ class {{ client_name }}Configuration(Configuration):
{% endif %}
if self.credential and not self.authentication_policy:
{% set credential_default_policy_type = ("Async" if (async_mode and config['credential_default_policy_type_has_async_version']) else "") + config['credential_default_policy_type'] %}
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
{% set credential_param_type = ("'" + config['credential_key_header_name'] + "', ") if config['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
{% endif %}
5 changes: 4 additions & 1 deletion tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def regen_expected(c, opts, debug):
args.append(f"--override-info.description={opts['override-info.description']}")
if opts.get('credential-default-policy-type'):
args.append(f"--credential-default-policy-type={opts['credential-default-policy-type']}")
if opts.get('credential-key-header-name'):
args.append(f"--credential-key-header-name={opts['credential-key-header-name']}")
if opts.get('package-name'):
args.append(f"--package-name={opts['package-name']}")
if opts.get('override-client-name'):
Expand Down Expand Up @@ -262,7 +264,8 @@ def regenerate_credential_default_policy(c, debug=False):
'azure_arm': True,
'flattening_threshold': '1',
'ns_prefix': True,
'credential-default-policy-type': 'AzureKeyCredentialPolicy'
'credential-default-policy-type': 'AzureKeyCredentialPolicy',
'credential-key-header-name': 'Authorization'
}
regen_expected(c, opts, debug)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.AsyncRedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ async def default_client(credential, authentication_policy):
from multiapicredentialdefaultpolicy.aio import MultiapiServiceClient
async with MultiapiServiceClient(
base_url="http://localhost:3000",
credential="12345",
name="azure_key_credential_policy"
credential="12345"
) as default_client:
await yield_(default_client)

def test_multiapi_credential_default_policy_type(default_client):
# making sure that the authentication policy is AzureKeyCredentialPolicy
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
assert default_client._config.authentication_policy._name == "Authorization"
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def default_client(authentication_policy):
from multiapicredentialdefaultpolicy import MultiapiServiceClient
with MultiapiServiceClient(
base_url="http://localhost:3000",
credential="12345",
name="azure_key_credential_policy"
credential="12345"
) as default_client:
yield default_client

def test_multiapi_credential_default_policy_type(default_client):
# making sure that the authentication policy is AzureKeyCredentialPolicy
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
assert default_client._config.authentication_policy._name == "Authorization"
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": ["https://management.azure.com/.default"],
"credential_default_policy_type": "BearerTokenCredentialPolicy",
"credential_default_policy_type_has_async_version": true
"credential_default_policy_type_has_async_version": true,
"credential_key_header_name": null
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": ["https://management.azure.com/.default"],
"credential_default_policy_type": "BearerTokenCredentialPolicy",
"credential_default_policy_type_has_async_version": true
"credential_default_policy_type_has_async_version": true,
"credential_key_header_name": null
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": ["https://management.azure.com/.default"],
"credential_default_policy_type": "BearerTokenCredentialPolicy",
"credential_default_policy_type_has_async_version": true
"credential_default_policy_type_has_async_version": true,
"credential_key_header_name": null
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.AsyncRedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": null,
"credential_default_policy_type": "AzureKeyCredentialPolicy",
"credential_default_policy_type_has_async_version": false
"credential_default_policy_type_has_async_version": false,
"credential_key_header_name": "Authorization"
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.AsyncRedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": null,
"credential_default_policy_type": "AzureKeyCredentialPolicy",
"credential_default_policy_type_has_async_version": false
"credential_default_policy_type_has_async_version": false,
"credential_key_header_name": "Authorization"
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.AsyncRedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def _configure(
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
self.authentication_policy = kwargs.get('authentication_policy')
if self.credential and not self.authentication_policy:
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
"credential": true,
"credential_scopes": null,
"credential_default_policy_type": "AzureKeyCredentialPolicy",
"credential_default_policy_type_has_async_version": false
"credential_default_policy_type_has_async_version": false,
"credential_key_header_name": "Authorization"
},
"operation_groups": {
"operation_group_one": "OperationGroupOneOperations",
Expand Down
Loading