diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py index e2132159509..f165b3407e5 100644 --- a/src/azure-cli-core/azure/cli/core/auth/identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -11,17 +11,22 @@ from azure.cli.core._environment import get_config_dir from knack.log import get_logger from knack.util import CLIError -from msal import PublicClientApplication +from msal import PublicClientApplication, ConfidentialClientApplication -# Service principal entry properties -from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION, \ - _USE_CERT_SN_ISSUER -from .msal_authentication import UserCredential, ServicePrincipalCredential +from .msal_credentials import UserCredential, ServicePrincipalCredential from .persistence import load_persisted_token_cache, file_extensions, load_secret_store from .util import check_result AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' +# Service principal entry properties. Names are taken from OAuth 2.0 client credentials flow parameters: +# https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow +_TENANT = 'tenant' +_CLIENT_ID = 'client_id' +_CLIENT_SECRET = 'client_secret' +_CERTIFICATE = 'certificate' +_CLIENT_ASSERTION = 'client_assertion' +_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' # For environment credential AZURE_AUTHORITY_HOST = "AZURE_AUTHORITY_HOST" @@ -29,6 +34,8 @@ AZURE_CLIENT_ID = "AZURE_CLIENT_ID" AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET" +FEDERATED_IDENTITY = "FEDERATED_IDENTITY" + WAM_PROMPT = ( "Select the account you want to log in with. " "For more information on login with Azure CLI, see https://go.microsoft.com/fwlink/?linkid=2271136") @@ -187,10 +194,9 @@ def login_with_service_principal(self, client_id, credential, scopes): `credential` is a dict returned by ServicePrincipalAuth.build_credential """ sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential) - - # This cred means SDK credential object - cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) - result = cred.acquire_token_for_client(scopes) + client_credential = sp_auth.get_msal_client_credential() + cca = ConfidentialClientApplication(client_id, client_credential=client_credential, **self._msal_app_kwargs) + result = cca.acquire_token_for_client(scopes) check_result(result) # Only persist the service principal after a successful login @@ -236,31 +242,45 @@ def get_user_credential(self, username): def get_service_principal_credential(self, client_id): entry = self._service_principal_store.load_entry(client_id, self.tenant_id) - sp_auth = ServicePrincipalAuth(entry) - return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) + client_credential = ServicePrincipalAuth(entry).get_msal_client_credential() + return ServicePrincipalCredential(client_id, client_credential, **self._msal_app_kwargs) def get_managed_identity_credential(self, client_id=None): raise NotImplementedError class ServicePrincipalAuth: - def __init__(self, entry): + # Initialize all attributes first, so that we don't need to call getattr to check their existence + self.client_id = None + self.tenant = None + # secret + self.client_secret = None + # certificate + self.certificate = None + self.use_cert_sn_issuer = None + # federated identity credential + self.client_assertion = None + + # Internal attributes for certificate + self._certificate_string = None + self._thumbprint = None + self._public_certificate = None + self.__dict__.update(entry) - if _CERTIFICATE in entry: + if self.certificate: from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error - self.public_certificate = None try: with open(self.certificate, 'r') as file_reader: - self.certificate_string = file_reader.read() - cert = load_certificate(FILETYPE_PEM, self.certificate_string) - self.thumbprint = cert.digest("sha1").decode().replace(':', '') + self._certificate_string = file_reader.read() + cert = load_certificate(FILETYPE_PEM, self._certificate_string) + self._thumbprint = cert.digest("sha1").decode().replace(':', '') if entry.get(_USE_CERT_SN_ISSUER): # low-tech but safe parsing based on # https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h match = re.search(r'-----BEGIN CERTIFICATE-----(?P[^-]+)-----END CERTIFICATE-----', - self.certificate_string, re.I) + self._certificate_string, re.I) self.public_certificate = match.group() except (UnicodeDecodeError, Error) as ex: raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex)) @@ -298,7 +318,41 @@ def build_credential(cls, secret_or_certificate=None, client_assertion=None, use def get_entry_to_persist(self): persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION] - return {k: v for k, v in self.__dict__.items() if k in persisted_keys} + # Only persist certain attributes whose values are not None + return {k: v for k, v in self.__dict__.items() if k in persisted_keys and v} + + def get_msal_client_credential(self): + client_credential = None + + # client_secret + # "your client secret" + if self.client_secret: + client_credential = self.client_secret + + # certificate + # { + # "private_key": "...-----BEGIN PRIVATE KEY-----... in PEM format", + # "thumbprint": "A1B2C3D4E5F6...", + # "public_certificate": "...-----BEGIN CERTIFICATE-----...", + # } + if self.certificate: + client_credential = { + "private_key": self._certificate_string, + "thumbprint": self._thumbprint + } + if self._public_certificate: + client_credential['public_certificate'] = self._public_certificate + + # client_assertion + # { + # "client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..." + # } + if self.client_assertion: + client_credential = { + 'client_assertion': get_id_token_on_github if self.client_assertion == FEDERATED_IDENTITY + else self.client_assertion} + + return client_credential class ServicePrincipalStore: @@ -405,3 +459,22 @@ def get_environment_credential(): getenv(AZURE_TENANT_ID)) credentials = ServicePrincipalCredential(sp_auth, authority=authority) return credentials + + +def get_id_token_on_github(): + import os + from urllib.parse import quote + import requests + token = os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN'] + url = os.environ['ACTIONS_ID_TOKEN_REQUEST_URL'] + encodedAudience = quote('api://AzureADTokenExchange') + url = f'{url}&audience={encodedAudience}' + headers = { + 'Authorization': f'bearer {token}', + 'Accept': 'application/json; api-version=2.0', + 'Content-Type': 'application/json' + } + result = requests.get(url, headers=headers) + id_token = result.json()['value'] + logger.warning('Got ID token: %s', id_token) + return id_token diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py similarity index 67% rename from src/azure-cli-core/azure/cli/core/auth/msal_authentication.py rename to src/azure-cli-core/azure/cli/core/auth/msal_credentials.py index b7b43ae32ba..041b009bcf4 100644 --- a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/auth/msal_credentials.py @@ -22,19 +22,10 @@ from .util import check_result, build_sdk_access_token -# OAuth 2.0 client credentials flow parameter -# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow -_TENANT = 'tenant' -_CLIENT_ID = 'client_id' -_CLIENT_SECRET = 'client_secret' -_CERTIFICATE = 'certificate' -_CLIENT_ASSERTION = 'client_assertion' -_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' - logger = get_logger(__name__) -class UserCredential(PublicClientApplication): +class UserCredential: # pylint: disable=too-few-public-methods def __init__(self, client_id, username, **kwargs): """User credential implementing get_token interface. @@ -42,12 +33,12 @@ def __init__(self, client_id, username, **kwargs): :param client_id: Client ID of the CLI. :param username: The username for user credential. """ - super().__init__(client_id, **kwargs) + self._msal_app = PublicClientApplication(client_id, **kwargs) # Make sure username is specified, otherwise MSAL returns all accounts assert username, "username must be specified, got {!r}".format(username) - accounts = self.get_accounts(username) + accounts = self._msal_app.get_accounts(username) # Usernames are usually unique. We are collecting corner cases to better understand its behavior. if len(accounts) > 1: @@ -65,8 +56,9 @@ def get_token(self, *scopes, claims=None, **kwargs): if claims: logger.warning('Acquiring new access token silently for tenant %s with claims challenge: %s', - self.authority.tenant, claims) - result = self.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, **kwargs) + self._msal_app.authority.tenant, claims) + result = self._msal_app.acquire_token_silent_with_error(list(scopes), self._account, claims_challenge=claims, + **kwargs) from azure.cli.core.azclierror import AuthenticationError try: @@ -82,13 +74,14 @@ def get_token(self, *scopes, claims=None, **kwargs): logger.warning(ex) logger.warning("\nThe default web browser has been opened at %s for scope '%s'. " "Please continue the login in the web browser.", - self.authority.authorization_endpoint, ' '.join(scopes)) + self._msal_app.authority.authorization_endpoint, ' '.join(scopes)) from .util import read_response_templates success_template, error_template = read_response_templates() - result = self.acquire_token_interactive( - list(scopes), login_hint=self._account['username'], port=8400 if self.authority.is_adfs else None, + result = self._msal_app.acquire_token_interactive( + list(scopes), login_hint=self._account['username'], + port=8400 if self._msal_app.authority.is_adfs else None, success_template=success_template, error_template=error_template, **kwargs) check_result(result) @@ -99,42 +92,18 @@ def get_token(self, *scopes, claims=None, **kwargs): return build_sdk_access_token(result) -class ServicePrincipalCredential(ConfidentialClientApplication): +class ServicePrincipalCredential: # pylint: disable=too-few-public-methods - def __init__(self, service_principal_auth, **kwargs): + def __init__(self, client_id, client_credential, **kwargs): """Service principal credential implementing get_token interface. :param service_principal_auth: An instance of ServicePrincipalAuth. """ - client_credential = None - - # client_secret - client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None) - if client_secret: - client_credential = client_secret - - # certificate - certificate = getattr(service_principal_auth, _CERTIFICATE, None) - if certificate: - client_credential = { - "private_key": getattr(service_principal_auth, 'certificate_string'), - "thumbprint": getattr(service_principal_auth, 'thumbprint') - } - public_certificate = getattr(service_principal_auth, 'public_certificate', None) - if public_certificate: - client_credential['public_certificate'] = public_certificate - - # client_assertion - client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None) - if client_assertion: - client_credential = {'client_assertion': client_assertion} - - super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs) + self._msal_app = ConfidentialClientApplication(client_id, client_credential, **kwargs) def get_token(self, *scopes, **kwargs): logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) - scopes = list(scopes) - result = self.acquire_token_for_client(scopes, **kwargs) + result = self._msal_app.acquire_token_for_client(list(scopes), **kwargs) check_result(result) return build_sdk_access_token(result) diff --git a/src/azure-cli/azure/cli/command_modules/profile/__init__.py b/src/azure-cli/azure/cli/command_modules/profile/__init__.py index dc8bba6e8f5..e1bde901c9c 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/__init__.py +++ b/src/azure-cli/azure/cli/command_modules/profile/__init__.py @@ -5,7 +5,7 @@ from azure.cli.core import AzCommandsLoader from azure.cli.core.commands import CliCommandType -from azure.cli.core.commands.parameters import get_enum_type +from azure.cli.core.commands.parameters import get_enum_type, get_three_state_flag from azure.cli.command_modules.profile._format import transform_account_list import azure.cli.command_modules.profile._help # pylint: disable=unused-import @@ -58,6 +58,8 @@ def load_arguments(self, command): c.argument('use_cert_sn_issuer', action='store_true', help='used with a service principal configured with Subject Name and Issuer Authentication in order to support automatic certificate rolls') c.argument('scopes', options_list=['--scope'], nargs='+', help='Used in the /authorize request. It can cover only one static resource.') c.argument('client_assertion', options_list=['--federated-token'], help='Federated token that can be used for OIDC token exchange.') + c.argument('federated_identity', options_list=['--federated-identity'], arg_type=get_three_state_flag(), + help='Use federated identity credential.') with self.argument_context('logout') as c: c.argument('username', help='account user, if missing, logout the current active account') diff --git a/src/azure-cli/azure/cli/command_modules/profile/custom.py b/src/azure-cli/azure/cli/command_modules/profile/custom.py index 0a049f1e267..350a88adf77 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/custom.py +++ b/src/azure-cli/azure/cli/command_modules/profile/custom.py @@ -116,7 +116,8 @@ def account_clear(cmd): # pylint: disable=inconsistent-return-statements, too-many-branches def login(cmd, username=None, password=None, service_principal=None, tenant=None, allow_no_subscriptions=False, - identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None, client_assertion=None): + identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None, client_assertion=None, + federated_identity=None): """Log in to access Azure subscriptions""" # quick argument usage check @@ -128,6 +129,9 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None raise CLIError("usage error: '--use-sn-issuer' is only applicable with a service principal") if service_principal and not username: raise CLIError('usage error: --service-principal --username NAME --password SECRET --tenant TENANT') + if client_assertion and federated_identity: + raise CLIError('usage error: Only one of --federated-token and --federated-identity can be specified') + if username and not service_principal and not identity: logger.warning(USERNAME_PASSWORD_DEPRECATION_WARNING) @@ -143,7 +147,7 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None logger.warning(_CLOUD_CONSOLE_LOGIN_WARNING) if username: - if not (password or client_assertion): + if not (password or client_assertion or federated_identity): try: password = prompt_pass('Password: ') except NoTTYException: @@ -153,7 +157,10 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None if service_principal: from azure.cli.core.auth.identity import ServicePrincipalAuth - password = ServicePrincipalAuth.build_credential(password, client_assertion, use_cert_sn_issuer) + password = ServicePrincipalAuth.build_credential( + secret_or_certificate=password, + client_assertion='FEDERATED_IDENTITY' if federated_identity else client_assertion, + use_cert_sn_issuer=use_cert_sn_issuer) login_experience_v2 = cmd.cli_ctx.config.getboolean('core', 'login_experience_v2', fallback=True) # Send login_experience_v2 config to telemetry