Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,43 +608,49 @@ def get_msal_token(self, scopes, data):
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
"""
from msal import ClientApplication
import posixpath
account = self.get_subscription()
username = account[_USER_ENTITY][_USER_NAME]
tenant = account[_TENANT_ID] or 'common'
_, refresh_token, _, _ = self.get_refresh_token()
identity_type = account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = account[_USER_ENTITY][_USER_NAME]
tenant = account[_TENANT_ID]

import posixpath
authority = posixpath.join(self.cli_ctx.cloud.endpoints.active_directory, tenant)
app = ClientApplication(_CLIENT_ID, authority=authority)
result = app.acquire_token_by_refresh_token(refresh_token, scopes, data=data)

if 'error' in result:
logger.warning(result['error_description'])
if identity_type == _USER:
# Use ARM as resource to get the refresh token from ADAL token cache
resource = self.cli_ctx.cloud.endpoints.active_directory_resource_id
_, _, token_entry = self._creds_cache.retrieve_token_for_user(
username_or_sp_id, account[_TENANT_ID], resource)
refresh_token = token_entry.get(_REFRESH_TOKEN)

# Retry login with VM SSH as resource
token_entry = self._login_with_authorization_code_flow(
tenant, 'https://pas.windows.net/CheckMyAccess/Linux')
result = app.acquire_token_by_refresh_token(token_entry['refreshToken'], scopes, data=data)
from azure.cli.core.msal_authentication import UserCredential
cred = UserCredential(_CLIENT_ID, authority=authority)
result = cred.acquire_token_by_refresh_token(refresh_token, scopes, data=data)

# In case of being rejected by Conditional Access, launch browser automatically to retry
# with VM SSH as resource.
if 'error' in result:
from azure.cli.core.adal_authentication import aad_error_handler
aad_error_handler(result)
return username, result["access_token"]
logger.warning(result['error_description'])

def get_refresh_token(self, resource=None,
subscription=None):
account = self.get_subscription(subscription)
user_type = account[_USER_ENTITY][_USER_TYPE]
username_or_sp_id = account[_USER_ENTITY][_USER_NAME]
resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id
from azure.cli.core.util import scopes_to_resource
token_entry = self._login_with_authorization_code_flow(tenant, scopes_to_resource(scopes))
result = cred.acquire_token_by_refresh_token(token_entry['refreshToken'], scopes, data=data)

if user_type == _USER:
_, _, token_entry = self._creds_cache.retrieve_token_for_user(
username_or_sp_id, account[_TENANT_ID], resource)
return None, token_entry.get(_REFRESH_TOKEN), token_entry[_ACCESS_TOKEN], str(account[_TENANT_ID])
elif identity_type == _SERVICE_PRINCIPAL:
from azure.cli.core.msal_authentication import ServicePrincipalCredential

sp_id = username_or_sp_id
sp_credential = self._creds_cache.retrieve_cred_for_service_principal(sp_id)
cred = ServicePrincipalCredential(sp_id, secret_or_certificate=sp_credential, authority=authority)
result = cred.get_token(scopes=scopes, data=data)
else:
raise CLIError("Identity type {} is currently unsupported".format(identity_type))

if 'error' in result:
from azure.cli.core.adal_authentication import aad_error_handler
aad_error_handler(result)

sp_secret = self._creds_cache.retrieve_cred_for_service_principal(username_or_sp_id)
return username_or_sp_id, sp_secret, None, str(account[_TENANT_ID])
return username_or_sp_id, result["access_token"]

def get_raw_token(self, resource=None, subscription=None, tenant=None):
logger.debug("Profile.get_raw_token invoked with resource=%r, subscription=%r, tenant=%r",
Expand Down
50 changes: 50 additions & 0 deletions src/azure-cli-core/azure/cli/core/msal_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
Credentials defined in this module are alternative implementations of credentials provided by Azure Identity.

These credentials implements azure.core.credentials.TokenCredential by exposing get_token method for Track 2
SDK invocation.
"""

import os

from knack.log import get_logger
from msal import PublicClientApplication, ConfidentialClientApplication

logger = get_logger(__name__)


class UserCredential(PublicClientApplication):

def get_token(self, scopes, **kwargs):
raise NotImplementedError


class ServicePrincipalCredential(ConfidentialClientApplication):

def __init__(self, client_id, secret_or_certificate=None, **kwargs):

# If certificate file path is provided, transfer it to MSAL input
if os.path.isfile(secret_or_certificate):
cert_file = secret_or_certificate
with open(cert_file, 'r') as f:
cert_str = f.read()

# Compute the thumbprint
from OpenSSL.crypto import load_certificate, FILETYPE_PEM
cert = load_certificate(FILETYPE_PEM, cert_str)
thumbprint = cert.digest("sha1").decode().replace(' ', '').replace(':', '')

client_credential = {"private_key": cert_str, "thumbprint": thumbprint}
else:
client_credential = secret_or_certificate

super().__init__(client_id, client_credential=client_credential, **kwargs)

def get_token(self, scopes, **kwargs):
logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs)
return self.acquire_token_for_client(scopes=scopes, **kwargs)
52 changes: 42 additions & 10 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,9 +1939,14 @@ def test_find_using_specific_tenant(self, _get_authorization_code_mock, mock_aut
self.assertEqual(all_subscriptions[0].tenant_id, token_tenant)
self.assertEqual(all_subscriptions[0].home_tenant_id, home_tenant)

@mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True)
@mock.patch('msal.ConfidentialClientApplication.acquire_token_for_client', autospec=True)
@mock.patch('azure.cli.core._profile.CredsCache.retrieve_cred_for_service_principal', autospec=True)
@mock.patch('msal.ClientApplication.acquire_token_by_refresh_token', autospec=True)
def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user):
@mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True)
@mock.patch('azure.cli.core._profile.Profile.get_subscription', autospec=True)
def test_get_msal_token(self, get_subscription_mock, retrieve_token_for_user_mock,
acquire_token_by_refresh_token_mock, retrieve_cred_for_service_principal_mock,
acquire_token_for_client_mock):
"""
This is added only for vmssh feature.
It is a temporary solution and will deprecate after MSAL adopted completely.
Expand All @@ -1950,23 +1955,50 @@ def test_get_msal_token(self, mock_acquire_token, mock_retrieve_token_for_user):
storage_mock = {'subscriptions': None}
profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False)

consolidated = profile._normalize_properties(self.user1, [self.subscription1], False)
profile._set_subscriptions(consolidated)
msal_result = {
'token_type': 'ssh-cert',
'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default',
'expires_in': 3599,
'ext_expires_in': 3599,
'access_token': 'fake_cert'
}

some_token_type = 'Bearer'
mock_retrieve_token_for_user.return_value = (some_token_type, TestProfile.raw_token1, TestProfile.token_entry1)
mock_acquire_token.return_value = {
'access_token': 'fake_access_token'
# User
get_subscription_mock.return_value = {
'tenantId': self.tenant_id,
'user': {
'name': self.user1,
'type': 'user'
},
}
scopes = ["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]

retrieve_token_for_user_mock.return_value = ('Bearer', self.raw_token1, self.token_entry1)
acquire_token_by_refresh_token_mock.return_value = msal_result

scopes = ["https://pas.windows.net/CheckMyAccess/Linux/.default"]
data = {
"token_type": "ssh-cert",
"req_cnf": "fake_jwk",
"key_id": "fake_id"
}
username, access_token = profile.get_msal_token(scopes, data)
self.assertEqual(username, self.user1)
self.assertEqual(access_token, 'fake_access_token')
self.assertEqual(access_token, 'fake_cert')

# Service Principal
sp_id = '610a3200-0000-0000-0000-000000000000'
get_subscription_mock.return_value = {
'tenantId': self.tenant_id,
'user': {
'name': sp_id,
'type': 'servicePrincipal'
},
}
retrieve_cred_for_service_principal_mock.return_value = "some_secret"
acquire_token_for_client_mock.return_value = msal_result
username, access_token = profile.get_msal_token(scopes, data)
self.assertEqual(username, sp_id)
self.assertEqual(access_token, 'fake_cert')


class FileHandleStub(object): # pylint: disable=too-few-public-methods
Expand Down
6 changes: 6 additions & 0 deletions src/azure-cli-core/azure/cli/core/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,12 @@ def test_scopes_to_resource(self):
self.assertEqual(scopes_to_resource(('https://managedhsm.azure.com/.default',)),
'https://managedhsm.azure.com')

# VM SSH
self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/.default"]),
'https://pas.windows.net/CheckMyAccess/Linux')
self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]),
'https://pas.windows.net/CheckMyAccess/Linux')

def test_resource_to_scopes(self):
from azure.cli.core.util import resource_to_scopes
# resource converted to a scopes list
Expand Down
8 changes: 6 additions & 2 deletions src/azure-cli-core/azure/cli/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,8 +1222,12 @@ def scopes_to_resource(scopes):
:rtype: str
"""
scope = scopes[0]
if scope.endswith("/.default"):
scope = scope[:-len("/.default")]

suffixes = ['/.default', '/user_impersonation']

for s in suffixes:
if scope.endswith(s):
return scope[:-len(s)]

return scope

Expand Down