Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,10 @@ def _try_parse_msi_account_name(account):
return parts[0], (None if len(parts) <= 1 else parts[1])
return None, None

def get_login_credentials(self, resource=None, subscription_id=None, aux_subscriptions=None):
def get_login_credentials(self, resource=None, subscription_id=None, aux_subscriptions=None, aux_tenants=None):
if aux_tenants and aux_subscriptions:
raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both")

account = self.get_subscription(subscription_id)
user_type = account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = account[_USER_ENTITY][_USER_NAME]
Expand All @@ -543,12 +546,14 @@ def get_login_credentials(self, resource=None, subscription_id=None, aux_subscri
identity_type, identity_id = Profile._try_parse_msi_account_name(account)

external_tenants_info = []
ext_subs = [aux_sub for aux_sub in (aux_subscriptions or []) if aux_sub != subscription_id]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or []) [](start = 61, length = 7)

if both ext_sub and ext_tenants are empty, the original code append sub = self.get_subscription([]), your code removed this logic, please check if it will cause regression

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, when aux_subscriptions is empty, ext_subs is [], sub = self.get_subscription (ext_sub) will not be executed.
I understand that the previous code was written this way to prevent aux_subscriptions from having null pointer problems.

for ext_sub in ext_subs:
sub = self.get_subscription(ext_sub)
if sub[_TENANT_ID] != account[_TENANT_ID]:
# external_tenants_info.append((sub[_USER_ENTITY][_USER_NAME], sub[_TENANT_ID]))
external_tenants_info.append(sub)
if aux_tenants:
external_tenants_info = [tenant for tenant in aux_tenants if tenant != account[_TENANT_ID]]
if aux_subscriptions:
ext_subs = [aux_sub for aux_sub in aux_subscriptions if aux_sub != subscription_id]
for ext_sub in ext_subs:
sub = self.get_subscription(ext_sub)
if sub[_TENANT_ID] != account[_TENANT_ID]:
external_tenants_info.append(sub[_TENANT_ID])

if identity_type is None:
def _retrieve_token():
Expand All @@ -564,13 +569,13 @@ def _retrieve_token():

def _retrieve_tokens_from_external_tenants():
external_tokens = []
for s in external_tenants_info:
for sub_tenant_id in external_tenants_info:
if user_type == _USER:
external_tokens.append(self._creds_cache.retrieve_token_for_user(
username_or_sp_id, s[_TENANT_ID], resource))
username_or_sp_id, sub_tenant_id, resource))
else:
external_tokens.append(self._creds_cache.retrieve_token_for_service_principal(
username_or_sp_id, resource, s[_TENANT_ID], resource))
username_or_sp_id, resource, sub_tenant_id, resource))
return external_tokens

from azure.cli.core.adal_authentication import AdalAuthentication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def resolve_client_arg_name(operation, kwargs):


def get_mgmt_service_client(cli_ctx, client_or_resource_type, subscription_id=None, api_version=None,
aux_subscriptions=None, **kwargs):
aux_subscriptions=None, aux_tenants=None, **kwargs):
"""
:params subscription_id: the current account's subscription
:param aux_subscriptions: mainly for cross tenant scenarios, say vnet peering.
Expand All @@ -66,6 +66,7 @@ def get_mgmt_service_client(cli_ctx, client_or_resource_type, subscription_id=No
client, _ = _get_mgmt_service_client(cli_ctx, client_type, subscription_id=subscription_id,
api_version=api_version, sdk_profile=sdk_profile,
aux_subscriptions=aux_subscriptions,
aux_tenants=aux_tenants,
**kwargs)
return client

Expand Down Expand Up @@ -118,13 +119,15 @@ def _get_mgmt_service_client(cli_ctx,
resource=None,
sdk_profile=None,
aux_subscriptions=None,
aux_tenants=None,
**kwargs):
from azure.cli.core._profile import Profile
logger.debug('Getting management service client client_type=%s', client_type.__name__)
resource = resource or cli_ctx.cloud.endpoints.active_directory_resource_id
profile = Profile(cli_ctx=cli_ctx)
cred, subscription_id, _ = profile.get_login_credentials(subscription_id=subscription_id, resource=resource,
aux_subscriptions=aux_subscriptions)
aux_subscriptions=aux_subscriptions,
aux_tenants=aux_tenants)

client_kwargs = {}
if base_url_bound:
Expand Down
54 changes: 54 additions & 0 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,60 @@ def test_get_login_credentials_aux_subscriptions(self, mock_get_token, mock_read

self.assertEqual(mock_get_token.call_count, 2)

@mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True)
@mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True)
def test_get_login_credentials_aux_tenants(self, mock_get_token, mock_read_cred_file):
cli = DummyCli()
raw_token2 = 'some...secrets2'
token_entry2 = {
"resource": "https://management.core.windows.net/",
"tokenType": "Bearer",
"_authority": "https://login.microsoftonline.com/common",
"accessToken": raw_token2,
}
some_token_type = 'Bearer'
mock_read_cred_file.return_value = [TestProfile.token_entry1, token_entry2]
mock_get_token.side_effect = [(some_token_type, TestProfile.raw_token1), (some_token_type, raw_token2)]
# setup
storage_mock = {'subscriptions': None}
profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False)
test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590'
test_subscription_id2 = '12345678-1bf0-4dda-aec3-cb9272f09591'
test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234'
test_tenant_id2 = '12345678-38d6-4fb2-bad9-b7b93a3e4321'
test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id),
'MSI-DEV-INC', self.state1, test_tenant_id)
test_subscription2 = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id2),
'MSI-DEV-INC2', self.state1, test_tenant_id2)
consolidated = profile._normalize_properties(self.user1,
[test_subscription, test_subscription2],
False)
profile._set_subscriptions(consolidated)
# test only input aux_tenants
cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id,
aux_tenants=[test_tenant_id2])

# verify
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)

token2 = cred._external_tenant_token_retriever()
self.assertEqual(len(token2), 1)
self.assertEqual(token2[0][1], raw_token2)

self.assertEqual(mock_get_token.call_count, 2)

# test input aux_tenants and aux_subscriptions
with self.assertRaisesRegexp(CLIError,
"Please specify only one of aux_subscriptions and aux_tenants, not both"):
cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id,
aux_subscriptions=[test_subscription_id2],
aux_tenants=[test_tenant_id2])

@mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True)
@mock.patch('msrestazure.azure_active_directory.MSIAuthentication', autospec=True)
def test_get_login_credentials_msi_system_assigned(self, mock_msi_auth, mock_read_cred_file):
Expand Down
14 changes: 10 additions & 4 deletions src/azure-cli/azure/cli/command_modules/resource/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,11 @@ def load_arguments(self, _):
with self.argument_context('group deployment create') as c:
c.argument('deployment_name', arg_type=deployment_create_name_type)
c.argument('handle_extended_json_format', arg_type=extended_json_format_type)
c.argument('aux_subscriptions', nargs='*', options_list=['--aux-subs'],
help='Auxiliary subscriptions which will be used during deployment across tenants.')
c.argument('aux_subscriptions', nargs='+', options_list=['--aux-subs'],
help='Auxiliary subscriptions which will be used during deployment across tenants.',
deprecate_info=c.deprecate(target='--aux-subs', redirect='--aux-tenants'))
c.argument('aux_tenants', nargs='+', options_list=['--aux-tenants'],
help='Auxiliary tenants which will be used during deployment across tenants.')

with self.argument_context('group deployment validate') as c:
c.argument('handle_extended_json_format', arg_type=extended_json_format_type)
Expand Down Expand Up @@ -238,8 +241,11 @@ def load_arguments(self, _):
with self.argument_context('deployment group create') as c:
c.argument('deployment_name', arg_type=deployment_create_name_type)
c.argument('handle_extended_json_format', arg_type=extended_json_format_type)
c.argument('aux_subscriptions', nargs='*', options_list=['--aux-subs'],
help='Auxiliary subscriptions which will be used during deployment across tenants.')
c.argument('aux_subscriptions', nargs='+', options_list=['--aux-subs'],
help='Auxiliary subscriptions which will be used during deployment across tenants.',
deprecate_info=c.deprecate(target='--aux-subs', redirect='--aux-tenants'))
c.argument('aux_tenants', nargs='+', options_list=['--aux-tenants'],
help='Auxiliary tenants which will be used during deployment across tenants.')

with self.argument_context('deployment group validate') as c:
c.argument('deployment_name', arg_type=deployment_create_name_type)
Expand Down
31 changes: 18 additions & 13 deletions src/azure-cli/azure/cli/command_modules/resource/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _urlretrieve(url):
def _deploy_arm_template_core(cli_ctx, resource_group_name,
template_file=None, template_uri=None, deployment_name=None,
parameters=None, mode=None, rollback_on_error=None, validate_only=False,
no_wait=False, aux_subscriptions=None):
no_wait=False, aux_subscriptions=None, aux_tenants=None):
DeploymentProperties, TemplateLink, OnErrorDeployment = get_sdk(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES,
'DeploymentProperties', 'TemplateLink',
'OnErrorDeployment', mod='models')
Expand Down Expand Up @@ -287,7 +287,8 @@ def _deploy_arm_template_core(cli_ctx, resource_group_name,
properties = DeploymentProperties(template=template, template_link=template_link,
parameters=parameters, mode=mode, on_error_deployment=on_error_deployment)

smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions=aux_subscriptions)
smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions=aux_subscriptions,
aux_tenants=aux_tenants)

validation_result = smc.deployments.validate(resource_group_name=resource_group_name, deployment_name=deployment_name, properties=properties)

Expand All @@ -312,7 +313,7 @@ def _remove_comments_from_json(template):
def _deploy_arm_template_core_unmodified(cli_ctx, resource_group_name, template_file=None,
template_uri=None, deployment_name=None, parameters=None,
mode=None, rollback_on_error=None, validate_only=False, no_wait=False,
aux_subscriptions=None):
aux_subscriptions=None, aux_tenants=None):
DeploymentProperties, TemplateLink, OnErrorDeployment = get_sdk(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES,
'DeploymentProperties', 'TemplateLink',
'OnErrorDeployment', mod='models')
Expand Down Expand Up @@ -343,7 +344,8 @@ def _deploy_arm_template_core_unmodified(cli_ctx, resource_group_name, template_
properties = DeploymentProperties(template=template_content, template_link=template_link,
parameters=parameters, mode=mode, on_error_deployment=on_error_deployment)

smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions=aux_subscriptions)
smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions=aux_subscriptions,
aux_tenants=aux_tenants)

deployment_client = smc.deployments # This solves the multi-api for you

Expand Down Expand Up @@ -473,14 +475,14 @@ def deploy_arm_template_at_resource_group(cmd,
template_file=None, template_uri=None, parameters=None,
deployment_name=None, mode=None, rollback_on_error=None,
no_wait=False, handle_extended_json_format=False,
aux_subscriptions=None):
aux_subscriptions=None, aux_tenants=None):
return _deploy_arm_template_at_resource_group(cli_ctx=cmd.cli_ctx,
resource_group_name=resource_group_name,
template_file=template_file, template_uri=template_uri, parameters=parameters,
deployment_name=deployment_name, mode=mode, rollback_on_error=rollback_on_error,
validate_only=False,
no_wait=no_wait, handle_extended_json_format=handle_extended_json_format,
aux_subscriptions=aux_subscriptions)
aux_subscriptions=aux_subscriptions, aux_tenants=aux_tenants)


def validate_arm_template_at_resource_group(cmd,
Expand All @@ -502,7 +504,7 @@ def _deploy_arm_template_at_resource_group(cli_ctx,
deployment_name=None, mode=None, rollback_on_error=None,
validate_only=False,
no_wait=False, handle_extended_json_format=False,
aux_subscriptions=None):
aux_subscriptions=None, aux_tenants=None):
deployment_properties = None
if handle_extended_json_format:
deployment_properties = _prepare_deployment_properties_unmodified(cli_ctx=cli_ctx, template_file=template_file,
Expand All @@ -515,7 +517,8 @@ def _deploy_arm_template_at_resource_group(cli_ctx,
parameters=parameters, mode=mode,
rollback_on_error=rollback_on_error)

mgmt_client = _get_deployment_management_client(cli_ctx, handle_extended_json_format=handle_extended_json_format, aux_subscriptions=aux_subscriptions)
mgmt_client = _get_deployment_management_client(cli_ctx, handle_extended_json_format=handle_extended_json_format,
aux_subscriptions=aux_subscriptions, aux_tenants=aux_tenants)

validation_result = mgmt_client.validate(resource_group_name=resource_group_name, deployment_name=deployment_name, properties=deployment_properties)

Expand Down Expand Up @@ -702,8 +705,10 @@ def _prepare_deployment_properties(cli_ctx, template_file=None, template_uri=Non
return properties


def _get_deployment_management_client(cli_ctx, handle_extended_json_format=False, aux_subscriptions=None):
smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions)
def _get_deployment_management_client(cli_ctx, handle_extended_json_format=False,
aux_subscriptions=None, aux_tenants=None):
smc = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, aux_subscriptions=aux_subscriptions,
aux_tenants=aux_tenants)
deployment_client = smc.deployments # This solves the multi-api for you

if handle_extended_json_format:
Expand Down Expand Up @@ -1230,18 +1235,18 @@ def delete_deployment_at_tenant_scope(cmd, deployment_name):
def deploy_arm_template(cmd, resource_group_name,
template_file=None, template_uri=None, deployment_name=None,
parameters=None, mode=None, rollback_on_error=None, no_wait=False,
handle_extended_json_format=False, aux_subscriptions=None):
handle_extended_json_format=False, aux_subscriptions=None, aux_tenants=None):
if handle_extended_json_format:
return _deploy_arm_template_core_unmodified(cmd.cli_ctx, resource_group_name=resource_group_name,
template_file=template_file, template_uri=template_uri,
deployment_name=deployment_name, parameters=parameters, mode=mode,
rollback_on_error=rollback_on_error, no_wait=no_wait,
aux_subscriptions=aux_subscriptions)
aux_subscriptions=aux_subscriptions, aux_tenants=aux_tenants)

return _deploy_arm_template_core(cmd.cli_ctx, resource_group_name=resource_group_name, template_file=template_file,
template_uri=template_uri, deployment_name=deployment_name,
parameters=parameters, mode=mode, rollback_on_error=rollback_on_error,
no_wait=no_wait, aux_subscriptions=aux_subscriptions)
no_wait=no_wait, aux_subscriptions=aux_subscriptions, aux_tenants=aux_tenants)


def validate_arm_template(cmd, resource_group_name, template_file=None, template_uri=None,
Expand Down
Loading