Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from azure.cli.core._environment import get_config_dir
from azure.cli.core._session import ACCOUNT
from azure.cli.core.util import get_file_json, in_cloud_console, open_page_in_browser, can_launch_browser,\
is_windows, is_wsl, scopes_to_resource
is_windows, is_wsl, scopes_to_resource, resource_to_scopes
from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription

logger = get_logger(__name__)
Expand Down Expand Up @@ -574,11 +574,7 @@ def get_login_credentials(self, resource=None, subscription_id=None, aux_subscri
"Please run `az login` with a user account or a service principal.")

if identity_type is None:
def _retrieve_token(sdk_resource=None):
# When called by
# - Track 1 SDK, use `resource` specified by CLI
# - Track 2 SDK, use `sdk_resource` specified by SDK and ignore `resource` specified by CLI
token_resource = sdk_resource or resource
def _retrieve_token(token_resource):
logger.debug("Retrieving token from ADAL for resource %r", token_resource)

if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID):
Expand All @@ -591,8 +587,7 @@ def _retrieve_token(sdk_resource=None):
account[_TENANT_ID],
use_cert_sn_issuer)

def _retrieve_tokens_from_external_tenants(sdk_resource=None):
token_resource = sdk_resource or resource
def _retrieve_tokens_from_external_tenants(token_resource):
logger.debug("Retrieving token from ADAL for external tenants and resource %r", token_resource)

external_tokens = []
Expand All @@ -607,7 +602,8 @@ def _retrieve_tokens_from_external_tenants(sdk_resource=None):

from azure.cli.core.adal_authentication import AdalAuthentication
auth_object = AdalAuthentication(_retrieve_token,
_retrieve_tokens_from_external_tenants if external_tenants_info else None)
_retrieve_tokens_from_external_tenants if external_tenants_info else None,
resource=resource)
else:
if self._msi_creds is None:
self._msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource)
Expand Down Expand Up @@ -675,7 +671,7 @@ def get_msal_token(self, scopes, data):
raise CLIError("Unknown identity type {}".format(identity_type))

if 'error' in result:
from azure.cli.core.adal_authentication import aad_error_handler
from azure.cli.core.auth.util import aad_error_handler
aad_error_handler(result)

return username_or_sp_id, result["access_token"]
Expand Down Expand Up @@ -721,7 +717,7 @@ def get_raw_token(self, resource=None, subscription=None, tenant=None):
use_cert_sn_issuer)
except adal.AdalError as ex:
from azure.cli.core.adal_authentication import adal_error_handler
adal_error_handler(ex)
adal_error_handler(ex, scopes=resource_to_scopes(resource))
return (creds,
None if tenant else str(account[_SUBSCRIPTION_ID]),
str(tenant if tenant else account[_TENANT_ID]))
Expand Down
36 changes: 15 additions & 21 deletions src/azure-cli-core/azure/cli/core/adal_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from msrest.authentication import Authentication
from msrestazure.azure_active_directory import MSIAuthentication
from azure.core.credentials import AccessToken
from azure.cli.core.util import in_cloud_console, scopes_to_resource
from azure.cli.core.util import in_cloud_console, scopes_to_resource, resource_to_scopes

from knack.util import CLIError
from knack.log import get_logger
Expand All @@ -19,7 +19,7 @@

class AdalAuthentication(Authentication): # pylint: disable=too-few-public-methods

def __init__(self, token_retriever, external_tenant_token_retriever=None):
def __init__(self, token_retriever, external_tenant_token_retriever=None, resource=None):
# DO NOT call _token_retriever from outside azure-cli-core. It is only available for user or
# Service Principal credential (AdalAuthentication), but not for Managed Identity credential
# (MSIAuthenticationWrapper).
Expand All @@ -28,24 +28,31 @@ def __init__(self, token_retriever, external_tenant_token_retriever=None):
# - AdalAuthentication.get_token, which is designed for Track 2 SDKs
self._token_retriever = token_retriever
self._external_tenant_token_retriever = external_tenant_token_retriever
self._resource = resource

def _get_token(self, sdk_resource=None):
"""
:param sdk_resource: `resource` converted from Track 2 SDK's `scopes`
"""

# When called by
# - Track 1 SDK, use `resource` specified by CLI
# - Track 2 SDK, use `sdk_resource` specified by SDK and ignore `resource` specified by CLI
token_resource = sdk_resource or self._resource

external_tenant_tokens = None
try:
scheme, token, token_entry = self._token_retriever(sdk_resource)
scheme, token, token_entry = self._token_retriever(token_resource)
if self._external_tenant_token_retriever:
external_tenant_tokens = self._external_tenant_token_retriever(sdk_resource)
external_tenant_tokens = self._external_tenant_token_retriever(token_resource)
except CLIError as err:
if in_cloud_console():
AdalAuthentication._log_hostname()
raise err
except adal.AdalError as err:
if in_cloud_console():
AdalAuthentication._log_hostname()
adal_error_handler(err)
adal_error_handler(err, scopes=resource_to_scopes(token_resource))
except requests.exceptions.SSLError as err:
from .util import SSLERROR_TEMPLATE
raise CLIError(SSLERROR_TEMPLATE.format(str(err)))
Expand Down Expand Up @@ -236,24 +243,11 @@ def _timestamp(dt):
return dt.timestamp()


def aad_error_handler(error: dict):
""" Handle the error from AAD server returned by ADAL or MSAL. """
login_message = ("To re-authenticate, please {}. If the problem persists, "
"please contact your tenant administrator."
.format("refresh Azure Portal" if in_cloud_console() else "run `az login`"))

# https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes
# Search for an error code at https://login.microsoftonline.com/error
msg = error.get('error_description')

from azure.cli.core.azclierror import AuthenticationError
raise AuthenticationError(msg, login_message)


def adal_error_handler(err: adal.AdalError):
def adal_error_handler(err: adal.AdalError, **kwargs):
""" Handle AdalError. """
try:
aad_error_handler(err.error_response)
from azure.cli.core.auth.util import aad_error_handler
aad_error_handler(err.error_response, **kwargs)
except AttributeError:
# In case of AdalError created as
# AdalError('More than one token matches the criteria. The result is ambiguous.')
Expand Down
45 changes: 45 additions & 0 deletions src/azure-cli-core/azure/cli/core/auth/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------


def aad_error_handler(error, **kwargs):
""" Handle the error from AAD server returned by ADAL or MSAL. """

# https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes
# Search for an error code at https://login.microsoftonline.com/error
msg = error.get('error_description')
login_message = _generate_login_message(**kwargs)

from azure.cli.core.azclierror import AuthenticationError
raise AuthenticationError(msg, recommendation=login_message)


def _generate_login_command(scopes=None):
login_command = ['az login']

if scopes:
login_command.append('--scope {}'.format(' '.join(scopes)))

return ' '.join(login_command)


def _generate_login_message(**kwargs):
from azure.cli.core.util import in_cloud_console
login_command = _generate_login_command(**kwargs)

msg = "To re-authenticate, please {}" .format(
"refresh Azure Portal." if in_cloud_console() else "run:\n{}".format(login_command))

return msg


def decode_access_token(access_token):
# Decode the access token. We can do the same with https://jwt.ms
from msal.oauth2cli.oidc import decode_part
import json

# Access token consists of headers.claims.signature. Decode the claim part
decoded_str = decode_part(access_token.split('.')[1])
return json.loads(decoded_str)
15 changes: 8 additions & 7 deletions src/azure-cli-core/azure/cli/core/tests/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def setUpClass(cls):
'e-lOym1sH5iOcxfIjXF0Tp2y0f3zM7qCq8Cp1ZxEwz6xYIgByoxjErNXrOME5Ld1WizcsaWxTXpwxJn_'
'Q8U2g9kXHrbYFeY2gJxF_hnfLvNKxUKUBnftmyYxZwKi0GDS0BvdJnJnsqSRSpxUx__Ra9QJkG1IaDzj'
'ZcSZPHK45T6ohK9Hk9ktZo0crVl7Tmw')
cls.arm_resource = 'https://management.core.windows.net/'

def test_normalize(self):
cli = DummyCli()
Expand Down Expand Up @@ -551,7 +552,7 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file):
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)
mock_get_token.assert_called_once_with(mock.ANY, self.user1, test_tenant_id,
Expand Down Expand Up @@ -595,11 +596,11 @@ def test_get_login_credentials_aux_subscriptions(self, mock_get_token, mock_read
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)

token2 = cred._external_tenant_token_retriever()
token2 = cred._external_tenant_token_retriever(self.arm_resource)
self.assertEqual(len(token2), 1)
self.assertEqual(token2[0][1], raw_token2)

Expand Down Expand Up @@ -642,11 +643,11 @@ def test_get_login_credentials_aux_tenants(self, mock_get_token, mock_read_cred_
self.assertEqual(subscription_id, test_subscription_id)

# verify the cred._tokenRetriever is a working lambda
token_type, token = cred._token_retriever()
token_type, token = cred._token_retriever(self.arm_resource)
self.assertEqual(token, self.raw_token1)
self.assertEqual(some_token_type, token_type)

token2 = cred._external_tenant_token_retriever()
token2 = cred._external_tenant_token_retriever(self.arm_resource)
self.assertEqual(len(token2), 1)
self.assertEqual(token2[0][1], raw_token2)

Expand Down Expand Up @@ -949,7 +950,7 @@ def test_get_login_credentials_for_graph_client(self, mock_get_token, mock_read_
# action
cred, _, tenant_id = profile.get_login_credentials(
resource=cli.cloud.endpoints.active_directory_graph_resource_id)
_, _ = cred._token_retriever()
_, _ = cred._token_retriever('https://graph.windows.net/')
# verify
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://graph.windows.net/')
Expand All @@ -971,7 +972,7 @@ def test_get_login_credentials_for_data_lake_client(self, mock_get_token, mock_r
# action
cred, _, tenant_id = profile.get_login_credentials(
resource=cli.cloud.endpoints.active_directory_data_lake_resource_id)
_, _ = cred._token_retriever()
_, _ = cred._token_retriever('https://datalake.azure.net/')
# verify
mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id,
'https://datalake.azure.net/')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from knack.util import CLIError


@unittest.skip("Out of maintenance")
class TestProfile(unittest.TestCase):

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions src/azure-cli-testsdk/azure/cli/testsdk/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ def __init__(self, method_name):
self.kwargs = {}
self.test_resources_count = 0

def setUp(self):
patch_main_exception_handler(self)
Comment on lines +217 to +218
Copy link
Member Author

@jiasli jiasli Aug 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

azure.cli.core.util.handle_exception is patched for ScenarioTest. We have to patch azure.cli.core.util.handle_exception for LiveScenarioTest as well. Otherwise, AuthenticationError can't be checked in tests.


def cmd(self, command, checks=None, expect_failure=False):
command = self._apply_kwargs(command)
return execute(self.cli_ctx, command, expect_failure=expect_failure).assert_with_checks(checks)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

from azure.cli.core.azclierror import AuthenticationError
from azure.cli.testsdk import LiveScenarioTest
from azure.cli.core.auth.util import decode_access_token

ARM_URL = "https://eastus2euap.management.azure.com/" # ARM canary
ARM_MAX_RETRY = 30
ARM_RETRY_INTERVAL = 10


class ConditionalAccessScenarioTest(LiveScenarioTest):

def setUp(self):
super().setUp()
# Clear MSAL cache to avoid unexpected tokens from cache
self.cmd('az account clear')

def test_conditional_access_mfa(self):
"""
This test should be run using a user account that
- doesn't require MFA for ARM
- requires MFA for data-plane resource

The result ATs are checked per https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens

Following claims are checked:
- aud (Audience): https://tools.ietf.org/html/rfc7519#section-4.1.3
- amr (Authentication Method Reference): https://tools.ietf.org/html/rfc8176
"""

resource = 'https://pas.windows.net/CheckMyAccess/Linux'
scope = resource + '/.default'

self.kwargs['scope'] = scope
self.kwargs['resource'] = resource

# region non-MFA session

# Login to ARM (MFA not required)
# In the browser, if the user already exists, make sure to logout first and re-login to clear browser cache
self.cmd('az login')

# Getting ARM AT and check claims
result = self.cmd('az account get-access-token').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] == self.cli_ctx.cloud.endpoints.active_directory_resource_id
assert decoded['amr'] == ['pwd']

# Getting data-plane AT with ARM RT (step-up) fails
with self.assertRaises(AuthenticationError) as cm:
self.cmd('az account get-access-token --resource {resource}')

# Check re-login recommendation
re_login_command = 'az login --scope {scope}'.format(**self.kwargs)
assert 'AADSTS50076' in cm.exception.error_msg
assert re_login_command in cm.exception.recommendations[0]

# endregion

# region MFA session

# Re-login with data-plane scope (MFA required)
# Getting ARM AT with data-plane RT (step-down) succeeds
self.cmd(re_login_command)

# Getting ARM AT and check claims
result = self.cmd('az account get-access-token').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] == self.cli_ctx.cloud.endpoints.active_directory_resource_id
assert decoded['amr'] == ['pwd']

# Getting data-plane AT and check claims
result = self.cmd('az account get-access-token --resource {resource}').get_output_in_json()
decoded = decode_access_token(result['accessToken'])
assert decoded['aud'] in scope
assert decoded['amr'] == ['pwd', 'mfa']

# endregion