Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 63 additions & 18 deletions src/azure-cli-core/azure/cli/core/_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,13 @@ def login(self,

def login_with_managed_identity(self, identity_id=None, client_id=None, object_id=None, resource_id=None,
allow_no_subscriptions=None):
if _on_azure_arc():
return self.login_with_managed_identity_azure_arc(
identity_id=identity_id, allow_no_subscriptions=allow_no_subscriptions)
if _use_msal_managed_identity(self.cli_ctx):
if identity_id:
raise CLIError('--username is not supported by MSAL managed identity. '
'Use --client-id, --object-id or --resource-id instead.')
return self.login_with_managed_identity_msal(
client_id=client_id, object_id=object_id, resource_id=resource_id,
allow_no_subscriptions=allow_no_subscriptions)

import jwt
from azure.mgmt.core.tools import is_valid_resource_id
Expand Down Expand Up @@ -304,22 +308,23 @@ def login_with_managed_identity(self, identity_id=None, client_id=None, object_i
self._set_subscriptions(consolidated)
return deepcopy(consolidated)

def login_with_managed_identity_azure_arc(self, identity_id=None, allow_no_subscriptions=None):
def login_with_managed_identity_msal(self, client_id=None, object_id=None, resource_id=None,
allow_no_subscriptions=None):
import jwt
identity_type = MsiAccountTypes.system_assigned
from .auth.msal_credentials import ManagedIdentityCredential
from .auth.constants import ACCESS_TOKEN

cred = ManagedIdentityCredential()
identity_id_type, identity_id_value = MsiAccountTypes.parse_ids(
client_id=client_id, object_id=object_id, resource_id=resource_id)
cred = MsiAccountTypes.msal_credential_factory(identity_id_type, identity_id_value)
token = cred.acquire_token(self._arm_scope)[ACCESS_TOKEN]
logger.info('Managed identity: token was retrieved. Now trying to initialize local accounts...')
decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False})
tenant = decode['tid']

subscription_finder = SubscriptionFinder(self.cli_ctx)
subscriptions = subscription_finder.find_using_specific_tenant(tenant, cred)
base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type)
user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY
base_name = ('{}-{}'.format(identity_id_type, identity_id_value) if identity_id_value else identity_id_type)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous variable name identity_type is not accurate. Identity type means systemAssignedIdentity or userAssignedIdentity.

user = _USER_ASSIGNED_IDENTITY if identity_id_value else _SYSTEM_ASSIGNED_IDENTITY
if not subscriptions:
if allow_no_subscriptions:
subscriptions = self._build_tenant_level_accounts([tenant])
Expand Down Expand Up @@ -399,10 +404,10 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au

elif managed_identity_type:
# managed identity
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
if _use_msal_managed_identity(self.cli_ctx):
# The credential must be wrapped by CredentialAdaptor so that it can work with Track 1 SDKs.
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
sdk_cred = CredentialAdaptor(cred)
else:
# The resource is merely used by msrestazure to get the first access token.
# It is not actually used in an API invocation.
Expand Down Expand Up @@ -432,7 +437,8 @@ def get_login_credentials(self, subscription_id=None, aux_subscriptions=None, au
str(account[_SUBSCRIPTION_ID]),
str(account[_TENANT_ID]))

def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=None):
def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=None, credential_out=None):
# credential_out is only used by unit tests to inspect the credential. Do not use it!
# Convert resource to scopes
if resource and not scopes:
from .auth.util import resource_to_scopes
Expand Down Expand Up @@ -460,9 +466,11 @@ def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=No
# managed identity
if tenant:
raise CLIError("Tenant shouldn't be specified for managed identity account")
if _on_azure_arc():
from .auth.msal_credentials import ManagedIdentityCredential
sdk_cred = CredentialAdaptor(ManagedIdentityCredential())
if _use_msal_managed_identity(self.cli_ctx):
cred = MsiAccountTypes.msal_credential_factory(managed_identity_type, managed_identity_id)
if credential_out:
credential_out['credential'] = cred
sdk_cred = CredentialAdaptor(cred)
else:
from .auth.util import scopes_to_resource
sdk_cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id,
Expand Down Expand Up @@ -810,6 +818,41 @@ def msi_auth_factory(cli_account_name, identity, resource):
return MSIAuthenticationWrapper(resource=resource, msi_res_id=identity)
raise ValueError("unrecognized msi account name '{}'".format(cli_account_name))

@staticmethod
def parse_ids(client_id=None, object_id=None, resource_id=None):
id_arg_count = len([arg for arg in (client_id, object_id, resource_id) if arg])
if id_arg_count > 1:
raise CLIError('Usage error: Provide only one of --client-id, --object-id, --resource-id.')

id_type = None
id_value = None
if id_arg_count == 0:
id_type = MsiAccountTypes.system_assigned
id_value = None
elif client_id:
id_type = MsiAccountTypes.user_assigned_client_id
id_value = client_id
elif object_id:
id_type = MsiAccountTypes.user_assigned_object_id
id_value = object_id
elif resource_id:
id_type = MsiAccountTypes.user_assigned_resource_id
id_value = resource_id
return id_type, id_value

@staticmethod
def msal_credential_factory(id_type, id_value):
from azure.cli.core.auth.msal_credentials import ManagedIdentityCredential
if id_type == MsiAccountTypes.system_assigned:
return ManagedIdentityCredential()
if id_type == MsiAccountTypes.user_assigned_client_id:
return ManagedIdentityCredential(client_id=id_value)
if id_type == MsiAccountTypes.user_assigned_object_id:
return ManagedIdentityCredential(object_id=id_value)
if id_type == MsiAccountTypes.user_assigned_resource_id:
return ManagedIdentityCredential(resource_id=id_value)
raise ValueError("Unrecognized managed identity ID type '{}'".format(id_type))


class SubscriptionFinder:
# An ARM client. It finds subscriptions for a user or service principal. It shouldn't do any
Expand Down Expand Up @@ -976,7 +1019,9 @@ def _create_identity_instance(cli_ctx, authority, tenant_id=None, client_id=None
instance_discovery=instance_discovery)


def _on_azure_arc():
def _use_msal_managed_identity(cli_ctx):
# This indicates an Azure Arc-enabled server
from msal.managed_identity import get_managed_identity_source, AZURE_ARC
return get_managed_identity_source() == AZURE_ARC
# PREVIEW: Use core.use_msal_managed_identity=true to enable managed identity authentication with MSAL
use_msal_managed_identity = cli_ctx.config.getboolean('core', 'use_msal_managed_identity', fallback=False)
return use_msal_managed_identity or get_managed_identity_source() == AZURE_ARC
11 changes: 8 additions & 3 deletions src/azure-cli-core/azure/cli/core/auth/msal_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from knack.log import get_logger
from knack.util import CLIError
from msal import (PublicClientApplication, ConfidentialClientApplication,
ManagedIdentityClient, SystemAssignedManagedIdentity)
ManagedIdentityClient, SystemAssignedManagedIdentity, UserAssignedManagedIdentity)

from .constants import AZURE_CLI_CLIENT_ID
from .util import check_result
Expand Down Expand Up @@ -131,9 +131,14 @@ class ManagedIdentityCredential: # pylint: disable=too-few-public-methods
Currently, only Azure Arc's system-assigned managed identity is supported.
"""

def __init__(self):
def __init__(self, client_id=None, resource_id=None, object_id=None):
import requests
self._msal_client = ManagedIdentityClient(SystemAssignedManagedIdentity(), http_client=requests.Session())
if client_id or resource_id or object_id:
managed_identity = UserAssignedManagedIdentity(
client_id=client_id, resource_id=resource_id, object_id=object_id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note (unblocking)
If there is not exactly one ID being used, a ManagedIdentityError exception will be thrown here. A fyi, in case you would want to catch it and provide your more suitable error message.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already check this by ourselves:

id_arg_count = len([arg for arg in (client_id, object_id, resource_id) if arg])
if id_arg_count > 1:
raise CLIError('Usage error: Provide only one of --client-id, --object-id, --resource-id.')

else:
managed_identity = SystemAssignedManagedIdentity()
self._msal_client = ManagedIdentityClient(managed_identity, http_client=requests.Session())

def acquire_token(self, scopes, **kwargs):
logger.debug("ManagedIdentityCredential.acquire_token: scopes=%r, kwargs=%r", scopes, kwargs)
Expand Down
Loading
Loading