Skip to content

Commit 06cab40

Browse files
authored
Add azure key credential param name (#736)
1 parent 0a343d3 commit 06cab40

File tree

36 files changed

+130
-68
lines changed

36 files changed

+130
-68
lines changed

ChangeLog.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ Modelerfour version: 4.15.378
77
**New Features**
88

99
- Add support for `x-ms-text` XML extension #722
10+
- Allow users to pass the name of the key header for `AzureKeyCredentialPolicy` during generation. To use, pass in
11+
`AzureKeyCredentialPolicy` with the `--credential-default-policy-type` flag, and pass in the key header name using
12+
the `--credential-key-header-name` flag #736
1013

1114
**Bug Fixes**
1215

autorest/codegen/__init__.py

Lines changed: 66 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,7 @@ def _create_code_model(self, yaml_data: Dict[str, Any], options: Dict[str, Union
122122

123123
return code_model
124124

125-
def _build_code_model_options(self) -> Dict[str, Any]:
126-
"""Build en options dict from the user input while running autorest.
127-
"""
128-
azure_arm = self._autorestapi.get_boolean_value("azure-arm", False)
129-
credential = (
130-
self._autorestapi.get_boolean_value("add-credentials", False) or
131-
self._autorestapi.get_boolean_value("add-credential", False)
132-
)
133-
125+
def _get_credential_scopes(self, credential):
134126
credential_scopes_temp = self._autorestapi.get_value("credential-scopes")
135127
credential_scopes = credential_scopes_temp.split(",") if credential_scopes_temp else None
136128
if credential_scopes and not credential:
@@ -142,6 +134,50 @@ def _build_code_model_options(self) -> Dict[str, Any]:
142134
"--credential-scopes takes a list of scopes in comma separated format. "
143135
"For example: --credential-scopes=https://cognitiveservices.azure.com/.default"
144136
)
137+
return credential_scopes
138+
139+
def _get_credential_param(self, azure_arm, credential, credential_default_policy_type):
140+
credential_scopes = self._get_credential_scopes(credential)
141+
credential_key_header_name = self._autorestapi.get_value('credential-key-header-name')
142+
143+
if credential_default_policy_type == "BearerTokenCredentialPolicy":
144+
if not credential_scopes:
145+
if azure_arm:
146+
credential_scopes = ["https://management.azure.com/.default"]
147+
elif credential:
148+
# If add-credential is specified, we still want to add a credential_scopes variable.
149+
# Will make it an empty list so we can differentiate between this case and None
150+
_LOGGER.warning(
151+
"You have used the --add-credential flag but not the --credential-scopes flag "
152+
"while generating non-management plane code. "
153+
"This is not recommend because it forces the customer to pass credential scopes "
154+
"through kwargs if they want to authenticate."
155+
)
156+
credential_scopes = []
157+
if credential_key_header_name:
158+
raise ValueError(
159+
"You have passed in a credential key header name with default credential policy type "
160+
"BearerTokenCredentialPolicy. This is not allowed, since credential key header name is tied with "
161+
"AzureKeyCredentialPolicy. Instead, with this policy it is recommend you pass in "
162+
"--credential-scopes."
163+
)
164+
else:
165+
# currently the only other credential policy is AzureKeyCredentialPolicy
166+
if credential_scopes:
167+
raise ValueError(
168+
"You have passed in credential scopes with default credential policy type "
169+
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
170+
"BearerTokenCredentialPolicy. Instead, with this policy you must pass in "
171+
"--credential-key-header-name."
172+
)
173+
if not credential_key_header_name:
174+
raise ValueError(
175+
"With default credential policy type AzureKeyCredentialPolicy, you must pass in the name "
176+
"of the key header with the flag --credential-key-header-name"
177+
)
178+
return credential_scopes, credential_key_header_name
179+
180+
def _handle_default_authentication_policy(self, azure_arm, credential):
145181

146182
passed_in_credential_default_policy_type = (
147183
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
@@ -159,27 +195,27 @@ def _build_code_model_options(self) -> Dict[str, Any]:
159195
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
160196
)
161197

162-
if credential_scopes and credential_default_policy_type != "BearerTokenCredentialPolicy":
163-
_LOGGER.warning(
164-
"You have --credential-default-policy-type not set as BearerTokenCredentialPolicy and a value for "
165-
"--credential-scopes. Since credential scopes are tied to the BearerTokenCredentialPolicy, "
166-
"we will ignore your credential scopes."
198+
credential_scopes, credential_key_header_name = self._get_credential_param(
199+
azure_arm, credential, credential_default_policy_type
200+
)
201+
202+
return credential_default_policy_type, credential_scopes, credential_key_header_name
203+
204+
205+
def _build_code_model_options(self) -> Dict[str, Any]:
206+
"""Build en options dict from the user input while running autorest.
207+
"""
208+
azure_arm = self._autorestapi.get_boolean_value("azure-arm", False)
209+
credential = (
210+
self._autorestapi.get_boolean_value("add-credentials", False) or
211+
self._autorestapi.get_boolean_value("add-credential", False)
212+
)
213+
214+
credential_default_policy_type, credential_scopes, credential_key_header_name = (
215+
self._handle_default_authentication_policy(
216+
azure_arm, credential
167217
)
168-
credential_scopes = []
169-
170-
elif not credential_scopes and credential_default_policy_type == "BearerTokenCredentialPolicy":
171-
if azure_arm:
172-
credential_scopes = ["https://management.azure.com/.default"]
173-
elif credential:
174-
# If add-credential is specified, we still want to add a credential_scopes variable.
175-
# Will make it an empty list so we can differentiate between this case and None
176-
_LOGGER.warning(
177-
"You have used the --add-credential flag but not the --credential-scopes flag "
178-
"while generating non-management plane code. "
179-
"This is not recommend because it forces the customer to pass credential scopes "
180-
"through kwargs if they want to authenticate."
181-
)
182-
credential_scopes = []
218+
)
183219

184220

185221
license_header = self._autorestapi.get_value("header-text")
@@ -194,6 +230,7 @@ def _build_code_model_options(self) -> Dict[str, Any]:
194230
"azure_arm": azure_arm,
195231
"credential": credential,
196232
"credential_scopes": credential_scopes,
233+
"credential_key_header_name": credential_key_header_name,
197234
"head_as_boolean": self._autorestapi.get_boolean_value("head-as-boolean", False),
198235
"license_header": license_header,
199236
"keep_version_file": self._autorestapi.get_boolean_value("keep-version-file", False),

autorest/codegen/templates/config.py.jinja2

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,6 @@ class {{ code_model.class_name }}Configuration(Configuration):
8888
{% endif %}
8989
if self.credential and not self.authentication_policy:
9090
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
91-
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
92-
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
91+
{% set credential_param_type = ("'" + code_model.options['credential_key_header_name'] + "', ") if code_model.options['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
92+
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
9393
{% endif %}

autorest/codegen/templates/metadata.json.jinja2

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
"credential": {{ code_model.options['credential'] | tojson }},
4747
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
4848
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
49-
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }}
49+
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
50+
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }}
5051
},
5152
"operation_groups": {
5253
{% for operation_group in code_model.operation_groups %}

autorest/multiapi/templates/multiapi_config.py.jinja2

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,6 @@ class {{ client_name }}Configuration(Configuration):
103103
{% endif %}
104104
if self.credential and not self.authentication_policy:
105105
{% set credential_default_policy_type = ("Async" if (async_mode and config['credential_default_policy_type_has_async_version']) else "") + config['credential_default_policy_type'] %}
106-
{% set bearer_token_specific_params = "*self.credential_scopes, " %}
107-
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ bearer_token_specific_params if "BearerTokenCredentialPolicy" in credential_default_policy_type }}**kwargs)
106+
{% set credential_param_type = ("'" + config['credential_key_header_name'] + "', ") if config['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
107+
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
108108
{% endif %}

tasks.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ def regen_expected(c, opts, debug):
149149
args.append(f"--override-info.description={opts['override-info.description']}")
150150
if opts.get('credential-default-policy-type'):
151151
args.append(f"--credential-default-policy-type={opts['credential-default-policy-type']}")
152+
if opts.get('credential-key-header-name'):
153+
args.append(f"--credential-key-header-name={opts['credential-key-header-name']}")
152154
if opts.get('package-name'):
153155
args.append(f"--package-name={opts['package-name']}")
154156
if opts.get('override-client-name'):
@@ -262,7 +264,8 @@ def regenerate_credential_default_policy(c, debug=False):
262264
'azure_arm': True,
263265
'flattening_threshold': '1',
264266
'ns_prefix': True,
265-
'credential-default-policy-type': 'AzureKeyCredentialPolicy'
267+
'credential-default-policy-type': 'AzureKeyCredentialPolicy',
268+
'credential-key-header-name': 'Authorization'
266269
}
267270
regen_expected(c, opts, debug)
268271

test/azure/Expected/AcceptanceTests/HeadWithAzureKeyCredentialPolicy/headwithazurekeycredentialpolicy/_configuration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,4 @@ def _configure(
6060
self.redirect_policy = kwargs.get('redirect_policy') or policies.RedirectPolicy(**kwargs)
6161
self.authentication_policy = kwargs.get('authentication_policy')
6262
if self.credential and not self.authentication_policy:
63-
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
63+
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)

test/azure/Expected/AcceptanceTests/HeadWithAzureKeyCredentialPolicy/headwithazurekeycredentialpolicy/aio/_configuration_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def _configure(
5656
self.redirect_policy = kwargs.get('redirect_policy') or policies.AsyncRedirectPolicy(**kwargs)
5757
self.authentication_policy = kwargs.get('authentication_policy')
5858
if self.credential and not self.authentication_policy:
59-
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, **kwargs)
59+
self.authentication_policy = policies.AzureKeyCredentialPolicy(self.credential, 'Authorization', **kwargs)

test/multiapi/AcceptanceTests/asynctests/test_multiapi_credential_default_policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ async def default_client(credential, authentication_policy):
3333
from multiapicredentialdefaultpolicy.aio import MultiapiServiceClient
3434
async with MultiapiServiceClient(
3535
base_url="http://localhost:3000",
36-
credential="12345",
37-
name="azure_key_credential_policy"
36+
credential="12345"
3837
) as default_client:
3938
await yield_(default_client)
4039

4140
def test_multiapi_credential_default_policy_type(default_client):
4241
# making sure that the authentication policy is AzureKeyCredentialPolicy
43-
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
42+
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
43+
assert default_client._config.authentication_policy._name == "Authorization"

test/multiapi/AcceptanceTests/test_multiapi_credential_default_policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ def default_client(authentication_policy):
3131
from multiapicredentialdefaultpolicy import MultiapiServiceClient
3232
with MultiapiServiceClient(
3333
base_url="http://localhost:3000",
34-
credential="12345",
35-
name="azure_key_credential_policy"
34+
credential="12345"
3635
) as default_client:
3736
yield default_client
3837

3938
def test_multiapi_credential_default_policy_type(default_client):
4039
# making sure that the authentication policy is AzureKeyCredentialPolicy
41-
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
40+
assert isinstance(default_client._config.authentication_policy, AzureKeyCredentialPolicy)
41+
assert default_client._config.authentication_policy._name == "Authorization"

0 commit comments

Comments
 (0)