diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index cb9b67a3ad6..48efd8cd557 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -608,43 +608,49 @@ def get_msal_token(self, scopes, data): This is added only for vmssh feature. It is a temporary solution and will deprecate after MSAL adopted completely. """ - from msal import ClientApplication - import posixpath account = self.get_subscription() - username = account[_USER_ENTITY][_USER_NAME] - tenant = account[_TENANT_ID] or 'common' - _, refresh_token, _, _ = self.get_refresh_token() + identity_type = account[_USER_ENTITY][_USER_TYPE] + username_or_sp_id = account[_USER_ENTITY][_USER_NAME] + tenant = account[_TENANT_ID] + + import posixpath authority = posixpath.join(self.cli_ctx.cloud.endpoints.active_directory, tenant) - app = ClientApplication(_CLIENT_ID, authority=authority) - result = app.acquire_token_by_refresh_token(refresh_token, scopes, data=data) - if 'error' in result: - logger.warning(result['error_description']) + if identity_type == _USER: + # Use ARM as resource to get the refresh token from ADAL token cache + resource = self.cli_ctx.cloud.endpoints.active_directory_resource_id + _, _, token_entry = self._creds_cache.retrieve_token_for_user( + username_or_sp_id, account[_TENANT_ID], resource) + refresh_token = token_entry.get(_REFRESH_TOKEN) - # Retry login with VM SSH as resource - token_entry = self._login_with_authorization_code_flow( - tenant, 'https://pas.windows.net/CheckMyAccess/Linux') - result = app.acquire_token_by_refresh_token(token_entry['refreshToken'], scopes, data=data) + from azure.cli.core.msal_authentication import UserCredential + cred = UserCredential(_CLIENT_ID, authority=authority) + result = cred.acquire_token_by_refresh_token(refresh_token, scopes, data=data) + # In case of being rejected by Conditional Access, launch browser automatically to retry + # with VM SSH as resource. if 'error' in result: - from azure.cli.core.adal_authentication import aad_error_handler - aad_error_handler(result) - return username, result["access_token"] + logger.warning(result['error_description']) - def get_refresh_token(self, resource=None, - subscription=None): - account = self.get_subscription(subscription) - user_type = account[_USER_ENTITY][_USER_TYPE] - username_or_sp_id = account[_USER_ENTITY][_USER_NAME] - resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id + from azure.cli.core.util import scopes_to_resource + token_entry = self._login_with_authorization_code_flow(tenant, scopes_to_resource(scopes)) + result = cred.acquire_token_by_refresh_token(token_entry['refreshToken'], scopes, data=data) - if user_type == _USER: - _, _, token_entry = self._creds_cache.retrieve_token_for_user( - username_or_sp_id, account[_TENANT_ID], resource) - return None, token_entry.get(_REFRESH_TOKEN), token_entry[_ACCESS_TOKEN], str(account[_TENANT_ID]) + elif identity_type == _SERVICE_PRINCIPAL: + from azure.cli.core.msal_authentication import ServicePrincipalCredential + + sp_id = username_or_sp_id + sp_credential = self._creds_cache.retrieve_cred_for_service_principal(sp_id) + cred = ServicePrincipalCredential(sp_id, secret_or_certificate=sp_credential, authority=authority) + result = cred.get_token(scopes=scopes, data=data) + else: + raise CLIError("Identity type {} is currently unsupported".format(identity_type)) + + if 'error' in result: + from azure.cli.core.adal_authentication import aad_error_handler + aad_error_handler(result) - sp_secret = self._creds_cache.retrieve_cred_for_service_principal(username_or_sp_id) - return username_or_sp_id, sp_secret, None, str(account[_TENANT_ID]) + return username_or_sp_id, result["access_token"] def get_raw_token(self, resource=None, subscription=None, tenant=None): logger.debug("Profile.get_raw_token invoked with resource=%r, subscription=%r, tenant=%r", diff --git a/src/azure-cli-core/azure/cli/core/msal_authentication.py b/src/azure-cli-core/azure/cli/core/msal_authentication.py new file mode 100644 index 00000000000..ffaaba0d927 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/msal_authentication.py @@ -0,0 +1,50 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Credentials defined in this module are alternative implementations of credentials provided by Azure Identity. + +These credentials implements azure.core.credentials.TokenCredential by exposing get_token method for Track 2 +SDK invocation. +""" + +import os + +from knack.log import get_logger +from msal import PublicClientApplication, ConfidentialClientApplication + +logger = get_logger(__name__) + + +class UserCredential(PublicClientApplication): + + def get_token(self, scopes, **kwargs): + raise NotImplementedError + + +class ServicePrincipalCredential(ConfidentialClientApplication): + + def __init__(self, client_id, secret_or_certificate=None, **kwargs): + + # If certificate file path is provided, transfer it to MSAL input + if os.path.isfile(secret_or_certificate): + cert_file = secret_or_certificate + with open(cert_file, 'r') as f: + cert_str = f.read() + + # Compute the thumbprint + from OpenSSL.crypto import load_certificate, FILETYPE_PEM + cert = load_certificate(FILETYPE_PEM, cert_str) + thumbprint = cert.digest("sha1").decode().replace(' ', '').replace(':', '') + + client_credential = {"private_key": cert_str, "thumbprint": thumbprint} + else: + client_credential = secret_or_certificate + + super().__init__(client_id, client_credential=client_credential, **kwargs) + + def get_token(self, scopes, **kwargs): + logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + return self.acquire_token_for_client(scopes=scopes, **kwargs) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 1ff0e4cbb95..0bb12e3208e 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -1939,9 +1939,14 @@ def test_find_using_specific_tenant(self, _get_authorization_code_mock, mock_aut self.assertEqual(all_subscriptions[0].tenant_id, token_tenant) self.assertEqual(all_subscriptions[0].home_tenant_id, home_tenant) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) + @mock.patch('msal.ConfidentialClientApplication.acquire_token_for_client', autospec=True) + @mock.patch('azure.cli.core._profile.CredsCache.retrieve_cred_for_service_principal', autospec=True) @mock.patch('msal.ClientApplication.acquire_token_by_refresh_token', autospec=True) - def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user): + @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) + @mock.patch('azure.cli.core._profile.Profile.get_subscription', autospec=True) + def test_get_msal_token(self, get_subscription_mock, retrieve_token_for_user_mock, + acquire_token_by_refresh_token_mock, retrieve_cred_for_service_principal_mock, + acquire_token_for_client_mock): """ This is added only for vmssh feature. It is a temporary solution and will deprecate after MSAL adopted completely. @@ -1950,15 +1955,27 @@ def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user): storage_mock = {'subscriptions': None} profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], False) - profile._set_subscriptions(consolidated) + msal_result = { + 'token_type': 'ssh-cert', + 'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default', + 'expires_in': 3599, + 'ext_expires_in': 3599, + 'access_token': 'fake_cert' + } - some_token_type = 'Bearer' - mock_retrieve_token_for_user.return_value = (some_token_type, TestProfile.raw_token1, TestProfile.token_entry1) - mock_acquire_token.return_value = { - 'access_token': 'fake_access_token' + # User + get_subscription_mock.return_value = { + 'tenantId': self.tenant_id, + 'user': { + 'name': self.user1, + 'type': 'user' + }, } - scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"] + + retrieve_token_for_user_mock.return_value = ('Bearer', self.raw_token1, self.token_entry1) + acquire_token_by_refresh_token_mock.return_value = msal_result + + scopes = ["https://pas.windows.net/CheckMyAccess/Linux/.default"] data = { "token_type": "ssh-cert", "req_cnf": "fake_jwk", @@ -1966,7 +1983,22 @@ def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user): } username, access_token = profile.get_msal_token(scopes, data) self.assertEqual(username, self.user1) - self.assertEqual(access_token, 'fake_access_token') + self.assertEqual(access_token, 'fake_cert') + + # Service Principal + sp_id = '610a3200-0000-0000-0000-000000000000' + get_subscription_mock.return_value = { + 'tenantId': self.tenant_id, + 'user': { + 'name': sp_id, + 'type': 'servicePrincipal' + }, + } + retrieve_cred_for_service_principal_mock.return_value = "some_secret" + acquire_token_for_client_mock.return_value = msal_result + username, access_token = profile.get_msal_token(scopes, data) + self.assertEqual(username, sp_id) + self.assertEqual(access_token, 'fake_cert') class FileHandleStub(object): # pylint: disable=too-few-public-methods diff --git a/src/azure-cli-core/azure/cli/core/tests/test_util.py b/src/azure-cli-core/azure/cli/core/tests/test_util.py index 544d21dd9ec..a0724ecf1b0 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_util.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_util.py @@ -397,6 +397,12 @@ def test_scopes_to_resource(self): self.assertEqual(scopes_to_resource(('https://managedhsm.azure.com/.default',)), 'https://managedhsm.azure.com') + # VM SSH + self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/.default"]), + 'https://pas.windows.net/CheckMyAccess/Linux') + self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]), + 'https://pas.windows.net/CheckMyAccess/Linux') + def test_resource_to_scopes(self): from azure.cli.core.util import resource_to_scopes # resource converted to a scopes list diff --git a/src/azure-cli-core/azure/cli/core/util.py b/src/azure-cli-core/azure/cli/core/util.py index af130a175c2..29f9e5edccd 100644 --- a/src/azure-cli-core/azure/cli/core/util.py +++ b/src/azure-cli-core/azure/cli/core/util.py @@ -1222,8 +1222,12 @@ def scopes_to_resource(scopes): :rtype: str """ scope = scopes[0] - if scope.endswith("/.default"): - scope = scope[:-len("/.default")] + + suffixes = ['/.default', '/user_impersonation'] + + for s in suffixes: + if scope.endswith(s): + return scope[:-len(s)] return scope