diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/_params.py b/src/azure-cli/azure/cli/command_modules/keyvault/_params.py index 8da92e09dbd..3d0083604b7 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/_params.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/_params.py @@ -77,6 +77,10 @@ class CLIJsonWebKeyCurveName(str, Enum): p_384 = "P-384" #: The NIST P-384 elliptic curve, AKA SECG curve SECP384R1. p_521 = "P-521" #: The NIST P-521 elliptic curve, AKA SECG curve SECP521R1. + class CLISecurityDomainOperation(str, Enum): + download = "download" #: Download operation + upload = "upload" #: Upload operation + (KeyPermissions, SecretPermissions, CertificatePermissions, StoragePermissions, NetworkRuleBypassOptions, NetworkRuleAction) = self.get_models( 'KeyPermissions', 'SecretPermissions', 'CertificatePermissions', 'StoragePermissions', @@ -467,6 +471,7 @@ class CLIJsonWebKeyCurveName(str, Enum): c.argument('hsm_name', hsm_url_type, required=False, help='Name of the HSM. Can be omitted if --id is specified.') c.extra('identifier', options_list=['--id'], validator=validate_vault_or_hsm, help='Id of the HSM.') + c.ignore('vault_base_url') with self.argument_context('keyvault security-domain init-recovery') as c: c.argument('sd_exchange_key', help='Local file path to store the exported key.') @@ -488,7 +493,6 @@ class CLIJsonWebKeyCurveName(str, Enum): help='Path to a file where the JSON blob returned by this command is stored.') c.argument('sd_quorum', type=int, help='The minimum number of shares required to decrypt the security domain ' 'for recovery.') - c.ignore('vault_base_url') with self.argument_context('keyvault security-domain wait') as c: c.argument('hsm_name', hsm_url_type, help='Name of the HSM. Can be omitted if --id is specified.', @@ -496,6 +500,9 @@ class CLIJsonWebKeyCurveName(str, Enum): c.argument('identifier', options_list=['--id'], validator=validate_vault_or_hsm, help='Id of the HSM.') c.argument('resource_group_name', options_list=['--resource-group', '-g'], help='Proceed only if HSM belongs to the specified resource group.') + c.argument('target_operation', arg_type=get_enum_type(CLISecurityDomainOperation), + help='Target operation that needs waiting.') + c.ignore('vault_base_url') # endregion # region keyvault backup/restore diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/commands.py b/src/azure-cli/azure/cli/command_modules/keyvault/commands.py index 0c11d80a4ab..877ad3a81a3 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/commands.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/commands.py @@ -139,7 +139,7 @@ def load_command_table(self, _): is_preview=True) as g: g.keyvault_custom('init-recovery', 'security_domain_init_recovery') g.keyvault_custom('upload', 'security_domain_upload', supports_no_wait=True) - g.keyvault_custom('download', 'security_domain_download') + g.keyvault_custom('download', 'security_domain_download', supports_no_wait=True) g.keyvault_custom('wait', '_wait_security_domain_operation') with self.command_group('keyvault key', data_entity.command_type) as g: diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/custom.py b/src/azure-cli/azure/cli/command_modules/keyvault/custom.py index e0bfee612a8..bae72cb9f7f 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/custom.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/custom.py @@ -2173,11 +2173,11 @@ def full_restore(cmd, client, token, folder_to_restore, storage_resource_uri=Non # region security domain def security_domain_init_recovery(client, hsm_name, sd_exchange_key, - identifier=None): # pylint: disable=unused-argument + identifier=None, vault_base_url=None): # pylint: disable=unused-argument if os.path.exists(sd_exchange_key): raise CLIError("File named '{}' already exists.".format(sd_exchange_key)) - ret = client.transfer_key(vault_base_url=hsm_name) + ret = client.transfer_key(vault_base_url=hsm_name or vault_base_url) exchange_key = json.loads(json.loads(ret)['transfer_key']) def get_x5c_as_pem(): @@ -2204,14 +2204,22 @@ def get_x5c_as_pem(): raise ex -def _wait_security_domain_operation(client, hsm_name, identifier=None): # pylint: disable=unused-argument +def _wait_security_domain_operation(client, hsm_name, target_operation='upload', + identifier=None, vault_base_url=None): # pylint: disable=unused-argument retries = 0 max_retries = 30 wait_second = 5 while retries < max_retries: try: - ret = client.upload_pending(vault_base_url=hsm_name) - if ret and getattr(ret, 'status', None) in ['Succeeded', 'Failed']: + ret = None + if target_operation == 'upload': + ret = client.upload_pending(vault_base_url=hsm_name or vault_base_url) + elif target_operation == 'download': + ret = client.download_pending(vault_base_url=hsm_name or vault_base_url) + + # v7.2-preview and v7.2 will change the upload operation from Sync to Async + # due to service defects, it returns 'Succeeded' before the change and 'Success' after the change + if ret and getattr(ret, 'status', None) in ['Succeeded', 'Success', 'Failed']: return ret except: # pylint: disable=bare-except pass @@ -2312,7 +2320,7 @@ def _security_domain_gen_blob(sd_exchange_key, share_arrays, enc_data, required) def security_domain_upload(cmd, client, hsm_name, sd_file, sd_exchange_key, sd_wrapping_keys, passwords=None, - identifier=None, no_wait=False): # pylint: disable=unused-argument + identifier=None, vault_base_url=None, no_wait=False): # pylint: disable=unused-argument resource_paths = [sd_file, sd_exchange_key] for p in resource_paths: if not os.path.exists(p): @@ -2351,19 +2359,21 @@ def security_domain_upload(cmd, client, hsm_name, sd_file, sd_exchange_key, sd_w ) SecurityDomainObject = cmd.get_models('SecurityDomainObject', resource_type=ResourceType.DATA_PRIVATE_KEYVAULT) security_domain = SecurityDomainObject(value=restore_blob_value) - retval = client.upload(vault_base_url=hsm_name, security_domain=security_domain) + retval = client.upload(vault_base_url=hsm_name or vault_base_url, security_domain=security_domain) if no_wait: return retval - new_retval = _wait_security_domain_operation(client, hsm_name) + wait_second = 5 + time.sleep(wait_second) + new_retval = _wait_security_domain_operation(client, hsm_name, 'upload', vault_base_url=vault_base_url) if new_retval: return new_retval return retval def security_domain_download(cmd, client, hsm_name, sd_wrapping_keys, security_domain_file, sd_quorum, - identifier=None, vault_base_url=None): # pylint: disable=unused-argument + identifier=None, vault_base_url=None, no_wait=False): # pylint: disable=unused-argument if os.path.exists(security_domain_file): raise CLIError("File named '{}' already exists.".format(security_domain_file)) @@ -2406,15 +2416,30 @@ def security_domain_download(cmd, client, hsm_name, sd_wrapping_keys, security_d certificates.append(sd_jwk) + # save security-domain backup value to local file + def _save_to_local_file(file_path, security_domain): + try: + with open(file_path, 'w') as f: + f.write(security_domain.value) + except Exception as ex: # pylint: disable=bare-except + if os.path.isfile(file_path): + os.remove(file_path) + from azure.cli.core.azclierror import FileOperationError + raise FileOperationError(str(ex)) + ret = client.download( vault_base_url=hsm_name or vault_base_url, certificates=CertificateSet(certificates=certificates, required=sd_quorum) ) - try: - with open(security_domain_file, 'w') as f: - f.write(ret.value) - except: # pylint: disable=bare-except - if os.path.isfile(security_domain_file): - os.remove(security_domain_file) + if not no_wait: + wait_second = 5 + time.sleep(wait_second) + polling_ret = _wait_security_domain_operation(client, hsm_name, 'download', vault_base_url=vault_base_url) + # Due to service defect, status could be 'Success' or 'Succeeded' when it succeeded + if polling_ret and getattr(polling_ret, 'status', None) != 'Failed': + _save_to_local_file(security_domain_file, ret) + return polling_ret + + _save_to_local_file(security_domain_file, ret) # endregion diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py index 8ee5f4de4b1..1493aa50b41 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py @@ -371,7 +371,7 @@ def test_keyvault_hsm_security_domain(self): # download SD self.cmd('az keyvault security-domain download --hsm-name {hsm_name} --security-domain-file "{sdfile}" ' - '--sd-quorum 2 --sd-wrapping-keys "{cer1_path}" "{cer2_path}" "{cer3_path}"') + '--sd-quorum 2 --sd-wrapping-keys "{cer1_path}" "{cer2_path}" "{cer3_path}" --no-wait') # delete the HSM self.cmd('az keyvault delete --hsm-name {hsm_name}') diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/vendored_sdks/azure_keyvault_t1/v7_2/key_vault_client.py b/src/azure-cli/azure/cli/command_modules/keyvault/vendored_sdks/azure_keyvault_t1/v7_2/key_vault_client.py index 7ad7a4b3e91..a2496ee5236 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/vendored_sdks/azure_keyvault_t1/v7_2/key_vault_client.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/vendored_sdks/azure_keyvault_t1/v7_2/key_vault_client.py @@ -11,6 +11,7 @@ # pylint: skip-file # flake8: noqa import json +import time from msrest.service_client import SDKClient from msrest import Serializer, Deserializer @@ -114,12 +115,16 @@ def download(self, vault_base_url, certificates, custom_headers=None, raw=False, response = self._client.send( request, header_parameters, body_content, stream=False, **operation_config) - if response.status_code not in [200]: + # v7.2-preview and v7.2 will introduce a breaking change to make the operation change from Sync to Async + # 200: for compatability of response before the change (Sync Operation) + # 202: for the support of new response after the change (Async Operation) + if response.status_code not in [200, 202]: raise models.KeyVaultErrorException(self._deserialize, response) deserialized = None - if response.status_code == 200: + # for both old response and new response + if response.status_code in [200, 202]: deserialized = self._deserialize('SecurityDomainObject', response) if raw: @@ -129,6 +134,58 @@ def download(self, vault_base_url, certificates, custom_headers=None, raw=False, return deserialized download.metadata = {'url': '/securitydomain/download'} + def download_pending(self, vault_base_url, custom_headers=None, raw=False, **operation_config): + """Get Security domain upload operation status. + :param vault_base_url: The vault name, for example https://myvault.vault.azure.net. + :type vault_base_url: str + :keyword callable cls: A custom type or function that will be passed the direct response + :return: SecurityDomainOperationStatus, or the result of cls(response) + :rtype: ~key_vault_client.models.SecurityDomainOperationStatus + :raises: ~azure.core.exceptions.HttpResponseError + """ + + # Construct URL + url = self.upload_pending.metadata['url'] + path_format_arguments = { + 'vaultBaseUrl': self._serialize.url("vault_base_url", vault_base_url, 'str', skip_quote=True) + } + url = self._client.format_url(url, **path_format_arguments) + + # Construct parameters + query_parameters = {} + query_parameters['api-version'] = self._serialize.query("self.api_version", self.api_version, 'str') + + # Construct headers + header_parameters = {} + header_parameters['Content-Type'] = 'application/json; charset=utf-8' + if self.config.generate_client_request_id: + header_parameters['x-ms-client-request-id'] = str(uuid.uuid1()) + if custom_headers: + header_parameters.update(custom_headers) + if self.config.accept_language is not None: + header_parameters['accept-language'] = self._serialize.header("self.config.accept_language", + self.config.accept_language, 'str') + + # Construct and send request + request = self._client.get(url, query_parameters) + response = self._client.send( + request, header_parameters, stream=False, **operation_config) + + if response.status_code not in [200]: + raise models.KeyVaultErrorException(self._deserialize, response) + + deserialized = None + + if response.status_code == 200: + deserialized = self._deserialize('SecurityDomainOperationStatus', response) + + if raw: + client_raw_response = ClientRawResponse(deserialized, response) + return client_raw_response + + return deserialized + download_pending.metadata = {'url': '/securitydomain/download/pending'} + def transfer_key(self, vault_base_url, custom_headers=None, raw=False, **operation_config): """Retrieve security domain transfer key. :param vault_base_url: The vault name, for example https://myvault.vault.azure.net.