diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 9877ec545ad..f7223264eea 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -360,17 +360,20 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No managed_identity_type, managed_identity_id = Profile._parse_managed_identity_account(account) + non_current_tenant_template = ("For {} account, getting access token for non-current tenants is not " + "supported. The specified tenant must be the current tenant " + f"{account[_TENANT_ID]}") if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): # Cloud Shell - if tenant: - raise CLIError("Tenant shouldn't be specified for Cloud Shell account") + if tenant and tenant != account[_TENANT_ID]: + raise CLIError(non_current_tenant_template.format('Cloud Shell')) from .auth.msal_credentials import CloudShellCredential cred = CloudShellCredential() elif managed_identity_type: # managed identity - if tenant: - raise CLIError("Tenant shouldn't be specified for managed identity account") + if tenant and tenant != account[_TENANT_ID]: + raise CLIError(non_current_tenant_template.format('managed identity')) cred = ManagedIdentityAuth.credential_factory(managed_identity_type, managed_identity_id) if credential_out: credential_out['credential'] = cred diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 061954f1766..0f899cebacb 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -1134,9 +1134,15 @@ def test_get_raw_token_mi_system_assigned(self): self.assertEqual(subscription_id, self.test_mi_subscription_id) self.assertEqual(tenant_id, self.test_mi_tenant) - # verify tenant shouldn't be specified for MSI account - with self.assertRaisesRegex(CLIError, "Tenant shouldn't be specified"): - cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) + # Specifying the current tenant is allowed + cred, subscription_id, tenant_id = profile.get_raw_token(tenant=self.test_mi_tenant) + self.assertEqual(tenant_id, self.test_mi_tenant) + + # Specifying a non-current tenant is disallowed + with self.assertRaisesRegex(CLIError, + "For managed identity account, getting access token for non-current tenants is " + "not supported"): + profile.get_raw_token(tenant='another-tenant') @mock.patch('azure.cli.core.auth.util.now_timestamp', new=now_timestamp_mock) @mock.patch('azure.cli.core.auth.msal_credentials.ManagedIdentityCredential', ManagedIdentityCredentialStub) @@ -1285,9 +1291,15 @@ def cloud_shell_credential_factory(): self.assertEqual(subscription_id, test_subscription_id) self.assertEqual(tenant_id, test_tenant_id) - # Verify tenant shouldn't be specified for Cloud Shell account - with self.assertRaisesRegex(CLIError, 'Cloud Shell'): - profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) + # Specifying the current tenant is allowed + cred, subscription_id, tenant_id = profile.get_raw_token(tenant=test_tenant_id) + self.assertEqual(tenant_id, test_tenant_id) + + # Specifying a non-current tenant is disallowed + with self.assertRaisesRegex(CLIError, + "For Cloud Shell account, getting access token for non-current tenants is " + "not supported"): + profile.get_raw_token(tenant='another-tenant') @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential') def test_get_msal_token(self, get_user_credential_mock):