diff --git a/src/ssh/azext_ssh/constants.py b/src/ssh/azext_ssh/constants.py index 9be7ef75418..4f26d291392 100644 --- a/src/ssh/azext_ssh/constants.py +++ b/src/ssh/azext_ssh/constants.py @@ -8,3 +8,4 @@ CLIENT_PROXY_STORAGE_URL = "https://sshproxysa.blob.core.windows.net" CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120 CLEANUP_TIME_INTERVAL_IN_SECONDS = 10 +DEFAULT_KEY_TEMPDIR_NAME = "azclisshkeys" diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index e93a6ac9353..02e53cc5644 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -29,7 +29,7 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, resource_id=None, ssh_ip _assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user) do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, None, None, - ssh_client_path, ssh_args, delete_privkey) + ssh_client_path, ssh_args, delete_privkey, local_user) do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user, cert_file, port, use_private_ip) @@ -40,7 +40,7 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip= _assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user) do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite, - None, None, None) + None, None, None, None) do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user, cert_file, port, use_private_ip) @@ -184,10 +184,16 @@ def _assert_args(resource_group, vm_name, ssh_ip, resource_id, cert_file, userna def _check_or_create_public_private_files(public_key_file, private_key_file): # If nothing is passed in create a temporary directory with a ephemeral keypair if not public_key_file and not private_key_file: - temp_dir = tempfile.mkdtemp(prefix="aadsshcert") + temp_dir = os.path.join(tempfile.gettempdir(), consts.DEFAULT_KEY_TEMPDIR_NAME) public_key_file = os.path.join(temp_dir, "id_rsa.pub") private_key_file = os.path.join(temp_dir, "id_rsa") - ssh_utils.create_ssh_keyfile(private_key_file) + if not os.path.isdir(temp_dir): + new_temp_dir = tempfile.mkdtemp() + os.rename(new_temp_dir, os.path.join(os.path.dirname(new_temp_dir), consts.DEFAULT_KEY_TEMPDIR_NAME)) + if not os.path.isfile(public_key_file) or not os.path.isfile(private_key_file): + file_utils.delete_file(public_key_file, f"Couldn't delete existing public key {public_key_file}. ") + file_utils.delete_file(private_key_file, f"Couldn't delete existing private key {private_key_file}. ") + ssh_utils.create_ssh_keyfile(private_key_file) if not public_key_file: if private_key_file: @@ -311,7 +317,7 @@ def _arc_list_access_details(cmd, resource_group, vm_name): def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite, - ssh_client_path, ssh_args, delete_privkey): + ssh_client_path, ssh_args, delete_privkey, local_user): # If the user provides an IP address the target will be treated as an Azure VM even if it is an # Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection. @@ -350,8 +356,11 @@ def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, conf op_call = functools.partial(ssh_utils.write_ssh_config, config_path=config_path, overwrite=overwrite, resource_group=resource_group_name) else: + delete_cert = False + if not local_user: + delete_cert = True op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args, - delete_privkey=delete_privkey) + delete_privkey=delete_privkey, delete_cert=delete_cert) do_ssh_op = functools.partial(_do_ssh_op, resource_group_name=resource_group_name, vm_name=vm_name, is_arc=is_arc_server, op_call=op_call) diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py index f3ecebf8b7f..01d94e5f7fb 100644 --- a/src/ssh/azext_ssh/file_utils.py +++ b/src/ssh/azext_ssh/file_utils.py @@ -27,13 +27,14 @@ def mkdir_p(path): def delete_file(file_path, message, warning=False): - try: - os.remove(file_path) - except Exception as e: - if warning: - logger.warning(message) - else: - raise azclierror.FileOperationError(message + "Error: " + str(e)) from e + if os.path.isfile(file_path): + try: + os.remove(file_path) + except Exception as e: + if warning: + logger.warning(message) + else: + raise azclierror.FileOperationError(message + "Error: " + str(e)) from e def create_directory(file_path, error_message): diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index d165267af88..d79681bff1e 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -22,7 +22,7 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_file, private_key_file, port, - is_arc, ssh_client_path, ssh_args, delete_privkey): + is_arc, ssh_client_path, ssh_args, delete_privkey, delete_cert): if not ssh_client_path: ssh_client_path = _get_ssh_path() @@ -33,11 +33,17 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil ssh_client_log_file_arg = [] # delete_privkey is only true for injected commands in the portal one click ssh experience - if delete_privkey and (cert_file or private_key_file): + # if delete_privkey is true, delete certificate and private key + # if delete_cert is true, delete only certificate + # I admit that these names are confusing. We Should probably rename them. + if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file): + + #Is there a point to do this if the os is Windows? It seems that it doesn't work, + #It's just extra work if '-E' in ssh_arg_list: - # This condition should rarely be true index = ssh_arg_list.index('-E') log_file = ssh_arg_list[index + 1] + # if the user provides their own log file, we should probably not overwrite it else: if cert_file: log_dir = os.path.dirname(cert_file) @@ -47,6 +53,7 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil log_file = os.path.join(log_dir, log_file_name) ssh_client_log_file_arg = ['-E', log_file] + # This might be a problem, because the user won't get the verbosity printed. if '-v' not in ssh_arg_list and '-vv' not in ssh_arg_list and '-vvv' not in ssh_arg_list: ssh_client_log_file_arg = ssh_client_log_file_arg + ['-v'] @@ -68,10 +75,10 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil # If delete_privkey flag is true, we will try to clean the private key file and the certificate file # once the connection has been established. If it's not possible to open the log file, we default to # waiting for about 2 minutes once the ssh process starts before cleaning up the files. - if delete_privkey and (cert_file or private_key_file): - if os.path.isfile(log_file): - file_utils.delete_file(log_file, f"Couldn't delete existing log file {log_file}", True) - cleanup_process = mp.Process(target=_do_cleanup, args=(private_key_file, cert_file, log_file)) + if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file): + #We shouldn't delete the file if it's provided by the user + file_utils.delete_file(log_file, f"Couldn't delete existing log file {log_file}", True) + cleanup_process = mp.Process(target=_do_cleanup, args=(private_key_file, cert_file, delete_privkey, delete_cert, log_file)) cleanup_process.start() logger.debug("Running ssh command %s", ' '.join(command)) @@ -79,11 +86,12 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil # If the cleanup process is still alive once the ssh process is terminated, we terminate it and make # sure the private key and certificate are deleted. - if delete_privkey and (cert_file or private_key_file): + # Should we also delete the log file if it's not provided by the user? Or will this not allow the user to ever look at their client side logs? + if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file): if cleanup_process.is_alive(): cleanup_process.terminate() time.sleep(1) - _do_cleanup(private_key_file, cert_file) + _do_cleanup(private_key_file, cert_file, delete_privkey, delete_cert) def create_ssh_keyfile(private_key_file): @@ -118,6 +126,7 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username, cert_file, private_key_file, port, is_arc, config_path, overwrite, resource_group): common_lines = [] + common_lines.append("Host " + resource_group + "-" + vm_name) common_lines.append("\tUser " + username) if cert_file: common_lines.append("\tCertificateFile " + cert_file) @@ -131,13 +140,17 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username, elif private_key_file: relay_info_dir = os.path.dirname(private_key_file) else: - relay_info_dir = tempfile.mkdtemp(prefix="ssharcrelayinfo") - relay_info_path = os.path.join(relay_info_dir, "relay_info") + relay_info_dir = os.path.join(tempfile.gettempdir(), const.DEFAULT_KEY_TEMPDIR_NAME) + if not os.path.isdir(relay_info_dir): + new_dir = tempfile.mkdtemp() + os.rename(new_dir, os.path.join(os.path.dirname(new_dir), const.DEFAULT_KEY_TEMPDIR_NAME)) + + relay_info_filename = "relay_info_" + vm_name + "_" + resource_group + relay_info_path = os.path.join(relay_info_dir, relay_info_filename) file_utils.write_to_file(relay_info_path, 'w', relay_info, f"Couldn't write relay information to file {relay_info_path}", 'utf-8') oschmod.set_mode(relay_info_path, stat.S_IRUSR) - lines.append("Host " + vm_name) lines = lines + common_lines if port: lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path + " " + "-p " + port) @@ -145,9 +158,8 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username, lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path) else: if resource_group and vm_name: - lines.append("Host " + resource_group + "-" + vm_name) - lines.append("\tHostName " + ip) lines = lines + common_lines + lines.append("\tHostName " + ip) if port: lines.append("\tPort " + port) @@ -208,8 +220,10 @@ def _build_args(cert_file, private_key_file, port): return private_key + certificate + port_arg -def _do_cleanup(private_key_file, cert_file, log_file=None): - if os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0": +def _do_cleanup(private_key_file, cert_file, delete_privkey, delete_cert, log_file=None): + print(delete_privkey) + print(delete_cert) + if delete_privkey and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0": raise azclierror.BadRequestError("Can't delete private key file. " "The --delete-private-key flag set to True, " "but this is not an Azure Cloud Shell session.") @@ -229,8 +243,8 @@ def _do_cleanup(private_key_file, cert_file, log_file=None): if t1 < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: time.sleep(const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS - t1) - if private_key_file and os.path.isfile(private_key_file): + if delete_privkey and private_key_file: file_utils.delete_file(private_key_file, f"Failed to delete private key file '{private_key_file}'. ") - if cert_file and os.path.isfile(cert_file): + if (delete_privkey or delete_cert) and cert_file: file_utils.delete_file(cert_file, f"Failed to delete certificate file '{cert_file}'. ") diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 49580239e07..bb331e5894b 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -230,7 +230,7 @@ def test_assert_args_cert_with_no_user(self): def test_assert_args_invalid_cert_filepath(self, mock_is_file): mock_is_file.return_value = False self.assertRaises(azclierror.FileOperationError, custom._assert_args, 'rg', 'vm', None, None, 'cert_path', 'username') - + ''' @mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile') @mock.patch('tempfile.mkdtemp') @mock.patch('os.path.isfile') @@ -255,7 +255,7 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf mock_create.assert_has_calls([ mock.call('/tmp/aadtemp/id_rsa') ]) - + ''' @mock.patch('os.path.isfile') @mock.patch('os.path.join') def test_check_or_create_public_private_files_no_public(self, mock_join, mock_isfile):