From 46bcf9552d32e58c9ed0ea1e5eb7f8ce057de80e Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 16 Aug 2022 13:59:43 -0400 Subject: [PATCH 1/9] Read stderr during ssh execution. Print only error messages as they appear --- src/ssh/HISTORY.md | 4 + src/ssh/azext_ssh/custom.py | 3 + src/ssh/azext_ssh/rdp_utils.py | 3 - src/ssh/azext_ssh/ssh_utils.py | 224 ++++++++++++--------------------- 4 files changed, 89 insertions(+), 145 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 4ecef1cb840..94dc6492a53 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,5 +1,9 @@ Release History =============== +1.1.3 +----- +* Fix bug where extension prints ssh banners after the connection is closed. + 1.1.2 ----- * Remove dependency to cryptography (Az CLI core alredy has cryptography) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index e396fcefcaa..a72f7d745d7 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -43,6 +43,9 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ # include openssh client logs to --debug output to make it easier to users to debug connection issued. if '--debug' in cmd.cli_ctx.data['safe_params'] and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_args): ssh_args = ['-vvv'] if not ssh_args else ['-vvv'] + ssh_args + + if '-E' in ssh_args: + raise azclierror.ArgumentUsageError("The -E SSH agrgument is not supported.") _assert_args(resource_group_name, vm_name, ssh_ip, resource_type, cert_file, local_user) diff --git a/src/ssh/azext_ssh/rdp_utils.py b/src/ssh/azext_ssh/rdp_utils.py index 63a6ff3c460..33f2ed14e18 100644 --- a/src/ssh/azext_ssh/rdp_utils.py +++ b/src/ssh/azext_ssh/rdp_utils.py @@ -107,9 +107,6 @@ def start_ssh_tunnel(op_info): else: op_info.ssh_args = ['-v'] + op_info.ssh_args - if '-E' in op_info.ssh_args: - raise azclierror.BadRequestError("Can't use -E ssh parameter when using --rdp") - command = [ssh_utils.get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] command = command + op_info.build_args() + op_info.ssh_args diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 1986a604431..4a462f8fc03 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -10,6 +10,7 @@ import datetime import re import colorama +import sys from knack import log from azure.cli.core import azclierror @@ -27,13 +28,18 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): try: # Initialize these so that if something fails in the try block before these # are initialized, then the finally block won't fail. - cleanup_process = None - log_file = None + connection_status = None ssh_arg_list = [] if op_info.ssh_args: ssh_arg_list = op_info.ssh_args + + print_ssh_logs = False + if not set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): + print_ssh_logs = True + else: + ssh_arg_list = ['-v'] + ssh_arg_list env = os.environ.copy() if op_info.is_arc(): @@ -46,9 +52,6 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): # In this case, even if delete_credentials is true, there is nothing to clean-up. op_info.delete_credentials = False - log_file, ssh_arg_list, cleanup_process = _start_cleanup(op_info.cert_file, op_info.private_key_file, - op_info.public_key_file, op_info.delete_credentials, - delete_keys, delete_cert, ssh_arg_list) command = command + op_info.build_args() + ssh_arg_list connection_duration = time.time() @@ -56,16 +59,13 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): # pylint: disable=subprocess-run-check try: - if set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list) or log_file: - connection_status = subprocess.run(command, shell=platform.system() == 'Windows', env=env, - stderr=subprocess.PIPE, encoding='utf-8') - else: - # Logs are sent to stderr. In that case, we shouldn't capture stderr. - connection_status = subprocess.run(command, shell=platform.system() == 'Windows', env=env) + ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') except OSError as e: colorama.init() raise azclierror.BadRequestError(f"Failed to run ssh command with error: {str(e)}.", const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) + + read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) connection_duration = (time.time() - connection_duration) / 60 ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': connection_duration} @@ -76,8 +76,33 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): finally: # Even if something fails between the creation of the credentials and the end of the ssh connection, we # want to make sure that all credentials are cleaned up, and that the clean up process is terminated. - _terminate_cleanup(delete_keys, delete_cert, op_info.delete_credentials, cleanup_process, op_info.cert_file, - op_info.private_key_file, op_info.public_key_file, log_file, connection_status) + do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + +def read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): + log_list = [] + ssh_sucess = False + + next_line = ssh_sub.stderr.readline() + while next_line: + if "debug1:" not in next_line and \ + "debug2:" not in next_line and \ + "debug3:" not in next_line: + print(next_line, end='', file=sys.stderr) + _check_for_known_errors(next_line, delete_cert, log_list) + elif print_ssh_logs: + print(next_line, end='', file=sys.stderr) + + log_list.append(next_line) + + if "debug1: Entering interactive session." in next_line: + logger.debug("SSH Connection estalished succesfully.") + do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + next_line = ssh_sub.stderr.readline() + + ssh_sub.wait() + return ssh_sucess def write_ssh_config(config_info, delete_keys, delete_cert): @@ -162,51 +187,45 @@ def get_ssh_cert_principals(cert_file, ssh_client_folder=None): return principals -def _print_error_messages_from_ssh_log(log_file, connection_status, delete_cert): - with open(log_file, 'r', encoding='utf-8') as ssh_log: - log_text = ssh_log.read() - log_lines = log_text.splitlines() - if ("debug1: Authentication succeeded" not in log_text and - not re.search("^Authenticated to .*\n", log_text, re.MULTILINE)) \ - or (connection_status and connection_status.returncode): - for line in log_lines: - if "debug1:" not in line: - print(line) - - # This connection fails when using our generated certificates. - # Only throw error if conection fails with AAD login. - if "Permission denied (publickey)." in log_text and delete_cert: - # pylint: disable=bare-except - # pylint: disable=too-many-boolean-expressions - # Check if OpenSSH client and server versions are incompatible - try: - regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' - local_major, local_minor = re.findall(regex, log_lines[0])[0] - remote_major, remote_minor = re.findall(regex, - file_utils.get_line_that_contains("remote software version", - log_lines))[0] - local_major = int(local_major) - local_minor = int(local_minor) - remote_major = int(remote_major) - remote_minor = int(remote_minor) - except: - ssh_log.close() - return - - if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ - (local_major > 8 or (local_major == 8 and local_minor >= 8)): - logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " - "Version incompatible with OpenSSH client version %d.%d. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - remote_major, remote_minor, local_major, local_minor) - - elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ - (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): - logger.warning("The OpenSSH client version %d.%d is too old. " - "Version incompatible with OpenSSH server version %d.%d in the target VM. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - local_major, local_minor, remote_major, remote_minor) - ssh_log.close() +def _check_for_known_errors(error_message, delete_cert, log_lines): + # This connection fails when using our generated certificates. + # Only throw error if conection fails with AAD login. + if "Permission denied (publickey)." in error_message and delete_cert: + # pylint: disable=bare-except + # pylint: disable=too-many-boolean-expressions + # Check if OpenSSH client and server versions are incompatible + try: + regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' + local_major, local_minor = re.findall(regex, log_lines[0])[0] + remote_major, remote_minor = re.findall(regex, + file_utils.get_line_that_contains("remote software version", + log_lines))[0] + local_major = int(local_major) + local_minor = int(local_minor) + remote_major = int(remote_major) + remote_minor = int(remote_minor) + except: + return + + if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ + (local_major > 8 or (local_major == 8 and local_minor >= 8)): + logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " + "Version incompatible with OpenSSH client version %d.%d. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + remote_major, remote_minor, local_major, local_minor) + + elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ + (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): + logger.warning("The OpenSSH client version %d.%d is too old. " + "Version incompatible with OpenSSH server version %d.%d in the target VM. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + local_major, local_minor, remote_major, remote_minor) + + regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error copying information from the connection: " + ".*\",\"time\":\".*\"}.*") + if re.search(regex, error_message): + logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " + "Arc Server. Ensure SSHD is running on the target machine.\n") def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): @@ -263,96 +282,17 @@ def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): return ssh_path -def do_cleanup(delete_keys, delete_cert, cert_file, private_key, public_key, log_file=None, wait=False): - if log_file: - t0 = time.time() - match = False - while (time.time() - t0) < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: - time.sleep(const.CLEANUP_TIME_INTERVAL_IN_SECONDS) - # pylint: disable=bare-except - # pylint: disable=anomalous-backslash-in-string - try: - with open(log_file, 'r', encoding='utf-8') as ssh_client_log: - log_text = ssh_client_log.read() - # The "debug1:..." message doesn't seems to exist in OpenSSH 3.9 - match = ("debug1: Authentication succeeded" in log_text or - re.search("^Authenticated to .*\n", log_text, re.MULTILINE)) - ssh_client_log.close() - except: - # If there is an exception, wait for a little bit and try again - time.sleep(const.CLEANUP_TIME_INTERVAL_IN_SECONDS) - - elif wait: - # if we are not checking the logs, but still want to wait for connection before deleting files - time.sleep(const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - +def do_cleanup(delete_keys, delete_cert, cert_file, private_key, public_key): if delete_keys and private_key: file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) if delete_keys and public_key: file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) if delete_cert and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) - - -def _start_cleanup(cert_file, private_key_file, public_key_file, delete_credentials, delete_keys, - delete_cert, ssh_arg_list): - log_file = None - cleanup_process = None - if delete_keys or delete_cert or delete_credentials: - if '-E' not in ssh_arg_list and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - # If the user either provides his own client log file (-E) or - # wants the client log messages to be printed to the console (-vvv/-vv/-v), - # we should not use the log files to check for connection success. - if cert_file: - log_dir = os.path.dirname(cert_file) - elif private_key_file: - log_dir = os.path.dirname(private_key_file) - log_file_name = 'ssh_client_log_' + str(os.getpid()) - log_file = os.path.join(log_dir, log_file_name) - ssh_arg_list = ['-E', log_file, '-v'] + ssh_arg_list - # Create a new process that will wait until the connection is established and then delete keys. - cleanup_process = mp.Process(target=do_cleanup, args=(delete_keys or delete_credentials, - delete_cert or delete_credentials, - cert_file, private_key_file, public_key_file, - log_file, True)) - cleanup_process.start() - - return log_file, ssh_arg_list, cleanup_process - - -def _terminate_cleanup(delete_keys, delete_cert, delete_credentials, cleanup_process, cert_file, - private_key_file, public_key_file, log_file, connection_status): - try: - if connection_status and connection_status.stderr: - if connection_status.returncode != 0: - # Check if stderr is a proxy error - regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error copying information from the connection: " - ".*\",\"time\":\".*\"}.*") - if re.search(regex, connection_status.stderr): - logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " - "Arc Server. Ensure SSHD is running on the target machine.") - print(connection_status.stderr) - finally: - if delete_keys or delete_cert or delete_credentials: - if cleanup_process and cleanup_process.is_alive(): - cleanup_process.terminate() - # wait for process to terminate - t0 = time.time() - while cleanup_process.is_alive() and (time.time() - t0) < const.CLEANUP_AWAIT_TERMINATION_IN_SECONDS: - time.sleep(1) - - if log_file and os.path.isfile(log_file): - _print_error_messages_from_ssh_log(log_file, connection_status, delete_cert) - - # Make sure all files have been properly removed. - do_cleanup(delete_keys or delete_credentials, delete_cert or delete_credentials, - cert_file, private_key_file, public_key_file) - if log_file: - file_utils.delete_file(log_file, f"Couldn't delete temporary log file {log_file}. ", True) - if delete_keys: - # This is only true if keys were generated, so they must be in a temp folder. - temp_dir = os.path.dirname(cert_file) - file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) + if delete_keys: + # This is only true if keys were generated, so they must be in a temp folder. + temp_dir = os.path.dirname(cert_file) + file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, relay_info_path, ssh_client_folder): From 9f8813c401f3505f612bdf793ca9f732f132ea14 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 16 Aug 2022 15:08:55 -0400 Subject: [PATCH 2/9] filter out known messages that are printed but not needed --- src/ssh/azext_ssh/ssh_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 4a462f8fc03..3431b8b283e 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -88,7 +88,15 @@ def read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): if "debug1:" not in next_line and \ "debug2:" not in next_line and \ "debug3:" not in next_line: - print(next_line, end='', file=sys.stderr) + + # Filter out known logs that don't start with "debug", + # but that are not useful error messages or banners. + if not next_line.startswith('Authenticated to') and \ + not next_line.startswith('Transferred: sent') and \ + not next_line.startswith('Bytes per second: sent') and \ + not next_line.startswith('OpenSSH_'): + print(next_line, end='', file=sys.stderr) + _check_for_known_errors(next_line, delete_cert, log_list) elif print_ssh_logs: print(next_line, end='', file=sys.stderr) From 8c12176d5cf3be6c69f2a04dcaf74a6afaeb5db9 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Tue, 16 Aug 2022 15:39:36 -0400 Subject: [PATCH 3/9] fix tests and style --- src/ssh/azext_ssh/custom.py | 2 +- src/ssh/azext_ssh/ssh_utils.py | 58 +++++++++---------- .../azext_ssh/tests/latest/test_ssh_utils.py | 47 ++++++++------- 3 files changed, 52 insertions(+), 55 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index a72f7d745d7..9bcdbfad3bd 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -43,7 +43,7 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ # include openssh client logs to --debug output to make it easier to users to debug connection issued. if '--debug' in cmd.cli_ctx.data['safe_params'] and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_args): ssh_args = ['-vvv'] if not ssh_args else ['-vvv'] + ssh_args - + if '-E' in ssh_args: raise azclierror.ArgumentUsageError("The -E SSH agrgument is not supported.") diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 3431b8b283e..081ee46b0d3 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -5,12 +5,11 @@ import os import platform import subprocess -import multiprocessing as mp import time import datetime import re -import colorama import sys +import colorama from knack import log from azure.cli.core import azclierror @@ -28,13 +27,13 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): try: # Initialize these so that if something fails in the try block before these # are initialized, then the finally block won't fail. - + connection_status = None ssh_arg_list = [] if op_info.ssh_args: ssh_arg_list = op_info.ssh_args - + print_ssh_logs = False if not set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): print_ssh_logs = True @@ -57,15 +56,15 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): connection_duration = time.time() logger.debug("Running ssh command %s", ' '.join(command)) - # pylint: disable=subprocess-run-check try: + # pylint: disable=consider-using-with ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') except OSError as e: colorama.init() raise azclierror.BadRequestError(f"Failed to run ssh command with error: {str(e)}.", const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) - - read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) + + _read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) connection_duration = (time.time() - connection_duration) / 60 ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': connection_duration} @@ -79,34 +78,34 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) -def read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): +def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): log_list = [] ssh_sucess = False - + next_line = ssh_sub.stderr.readline() - while next_line: + while next_line: if "debug1:" not in next_line and \ "debug2:" not in next_line and \ "debug3:" not in next_line: - # Filter out known logs that don't start with "debug", + # Filter out known logs that don't start with "debug", # but that are not useful error messages or banners. if not next_line.startswith('Authenticated to') and \ not next_line.startswith('Transferred: sent') and \ not next_line.startswith('Bytes per second: sent') and \ not next_line.startswith('OpenSSH_'): print(next_line, end='', file=sys.stderr) - + _check_for_known_errors(next_line, delete_cert, log_list) - elif print_ssh_logs: - print(next_line, end='', file=sys.stderr) + elif print_ssh_logs: + print(next_line, end='', file=sys.stderr) log_list.append(next_line) - + if "debug1: Entering interactive session." in next_line: logger.debug("SSH Connection estalished succesfully.") - do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) - + do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + next_line = ssh_sub.stderr.readline() ssh_sub.wait() @@ -205,9 +204,8 @@ def _check_for_known_errors(error_message, delete_cert, log_lines): try: regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' local_major, local_minor = re.findall(regex, log_lines[0])[0] - remote_major, remote_minor = re.findall(regex, - file_utils.get_line_that_contains("remote software version", - log_lines))[0] + remote_version_line = file_utils.get_line_that_contains("remote software version", log_lines) + remote_major, remote_minor = re.findall(regex, remote_version_line)[0] local_major = int(local_major) local_minor = int(local_minor) remote_major = int(remote_major) @@ -217,23 +215,23 @@ def _check_for_known_errors(error_message, delete_cert, log_lines): if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ (local_major > 8 or (local_major == 8 and local_minor >= 8)): - logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " - "Version incompatible with OpenSSH client version %d.%d. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - remote_major, remote_minor, local_major, local_minor) + logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " + "Version incompatible with OpenSSH client version %d.%d. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + remote_major, remote_minor, local_major, local_minor) elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): - logger.warning("The OpenSSH client version %d.%d is too old. " - "Version incompatible with OpenSSH server version %d.%d in the target VM. " - "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", - local_major, local_minor, remote_major, remote_minor) + logger.warning("The OpenSSH client version %d.%d is too old. " + "Version incompatible with OpenSSH server version %d.%d in the target VM. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + local_major, local_minor, remote_major, remote_minor) regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error copying information from the connection: " ".*\",\"time\":\".*\"}.*") if re.search(regex, error_message): - logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " - "Arc Server. Ensure SSHD is running on the target machine.\n") + logger.error("Please make sure SSH port is allowed using \"azcmagent config list\" in the target " + "Arc Server. Ensure SSHD is running on the target machine.\n") def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 26877814d08..b7d88d9af78 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -12,14 +12,13 @@ from azext_ssh import ssh_utils from azext_ssh import ssh_info -class SSHUtilsTests(unittest.TestCase): - @mock.patch.object(ssh_utils, '_start_cleanup') - @mock.patch.object(ssh_utils, '_terminate_cleanup') +class SSHUtilsTests(unittest.TestCase): + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_read_ssh_logs') @mock.patch.object(ssh_utils, 'get_ssh_client_path') - @mock.patch('subprocess.run') + @mock.patch('subprocess.Popen') @mock.patch('os.environ.copy') - @mock.patch('platform.system') - def test_start_ssh_connection_compute(self, mock_system, mock_copy_env, mock_call, mock_path, mock_terminatecleanup, mock_startcleanup): + def test_start_ssh_connection_compute(self, mock_copy_env, mock_call, mock_path, mock_read, mock_cleanup): op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) op_info.public_key_file = "pub" @@ -27,28 +26,28 @@ def test_start_ssh_connection_compute(self, mock_system, mock_copy_env, mock_cal op_info.cert_file = "cert" op_info.ssh_client_folder = "client" - mock_system.return_value = 'Windows' mock_call.return_value = 0 mock_path.return_value = 'ssh' + mock_call.return_value = 'ssh_process' mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} - mock_startcleanup.return_value = 'log', ['arg1', 'arg2', 'arg3', '-E', 'log', '-v'], 'cleanup process' - expected_command = ['ssh', 'user@ip', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3', '-E', 'log', '-v'] + expected_command = ['ssh', 'user@ip', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', '-v', 'arg1', 'arg2', 'arg3'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} ssh_utils.start_ssh_connection(op_info, True, True) mock_path.assert_called_once_with('ssh', 'client') - mock_startcleanup.assert_called_with('cert', 'priv', 'pub', False, True, True, ['arg1', 'arg2', 'arg3']) - mock_call.assert_called_once_with(expected_command, shell=True, env=expected_env, stderr=mock.ANY, encoding='utf-8') - mock_terminatecleanup.assert_called_once_with(True, True, False, 'cleanup process', 'cert', 'priv', 'pub', 'log', 0) - - @mock.patch.object(ssh_utils, '_terminate_cleanup') - @mock.patch('os.environ.copy') + mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') + mock_read.assert_called_once_with('ssh_process', False, op_info, True, True) + mock_cleanup.assert_called_once_with(True, True, 'cert', 'priv', 'pub') + + + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_read_ssh_logs') @mock.patch.object(ssh_utils, 'get_ssh_client_path') - @mock.patch('subprocess.run') + @mock.patch('os.environ.copy') + @mock.patch('subprocess.Popen') @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') - @mock.patch('platform.system') - def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, mock_path, mock_copy_env, mock_terminatecleanup): + def test_start_ssh_connection_arc(self, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_read, mock_cleanup): op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) op_info.public_key_file = "pub" @@ -58,22 +57,22 @@ def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, op_info.proxy_path = "proxy" op_info.relay_info = "relay" - mock_system.return_value = 'Linux' - mock_call.return_value = 0 + mock_call.return_value = 'ssh_process' mock_relay_str.return_value = 'relay_string' mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} mock_path.return_value = 'ssh' - expected_command = ['ssh', 'user@vm', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', 'arg1'] + expected_command = ['ssh', 'user@vm', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-v', 'arg1'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} ssh_utils.start_ssh_connection(op_info, False, False) mock_relay_str.assert_called_once_with('relay') mock_path.assert_called_once_with('ssh', 'client') - mock_call.assert_called_once_with(expected_command, shell=False, env=expected_env, stderr=mock.ANY, encoding='utf-8') - mock_terminatecleanup.assert_called_once_with(False, False, False, None, 'cert', 'priv', 'pub', None, 0) - + mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') + mock_cleanup.assert_called_once_with(False, False, 'cert', 'priv', 'pub') + mock_read.assert_called_once_with('ssh_process', False, op_info, False, False) + @mock.patch.object(ssh_utils, '_issue_config_cleanup_warning') @mock.patch('os.path.abspath') def test_write_ssh_config_ip_and_vm_compute_append(self, mock_abspath, mock_warning): From f991d562e0a270321345820d676626129181c355 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Fri, 26 Aug 2022 10:22:12 -0400 Subject: [PATCH 4/9] banner changes --- src/ssh/azext_ssh/ssh_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 081ee46b0d3..a2248875802 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -94,11 +94,13 @@ def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): not next_line.startswith('Transferred: sent') and \ not next_line.startswith('Bytes per second: sent') and \ not next_line.startswith('OpenSSH_'): - print(next_line, end='', file=sys.stderr) + #print(next_line, end='', file=sys.stderr) + sys.stderr.write(next_line) _check_for_known_errors(next_line, delete_cert, log_list) elif print_ssh_logs: - print(next_line, end='', file=sys.stderr) + #print(next_line, end='', file=sys.stderr) + sys.stderr.write(next_line) log_list.append(next_line) @@ -107,7 +109,6 @@ def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) next_line = ssh_sub.stderr.readline() - ssh_sub.wait() return ssh_sucess From 7f10b83ad01377d15bb1797c3fd34470cd1d6a24 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 8 Sep 2022 10:54:32 -0400 Subject: [PATCH 5/9] handle stderr differently depending on os and resource type --- src/ssh/azext_ssh/ssh_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index a2248875802..e01f0ee30da 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -80,7 +80,6 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): log_list = [] - ssh_sucess = False next_line = ssh_sub.stderr.readline() while next_line: @@ -110,7 +109,6 @@ def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): next_line = ssh_sub.stderr.readline() ssh_sub.wait() - return ssh_sucess def write_ssh_config(config_info, delete_keys, delete_cert): From d8351f327c72fcadcb71ccb79a4bed5a0db94e18 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Wed, 14 Sep 2022 17:22:12 -0400 Subject: [PATCH 6/9] Mitigate weird issues on linux --- src/ssh/azext_ssh/ssh_utils.py | 45 ++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index e01f0ee30da..8449f0dc787 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -57,15 +57,18 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): logger.debug("Running ssh command %s", ' '.join(command)) try: - # pylint: disable=consider-using-with - ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') + # In these cases, there is no reason to read the logs. Not redirect stderr to avoid complications. + if (platform.system() != 'Windows' and not delete_cert or\ + platform.system() == 'Windows' and not op_info.is_arc() and not delete_cert): + ssh_process = subprocess.Popen(command, env=env, encoding='utf-8') + else: + ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') + _read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) except OSError as e: colorama.init() raise azclierror.BadRequestError(f"Failed to run ssh command with error: {str(e)}.", const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) - _read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) - connection_duration = (time.time() - connection_duration) / 60 ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': connection_duration} if connection_status and connection_status.returncode == 0: @@ -80,34 +83,40 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): log_list = [] + connection_established = False + t0 = time.time() next_line = ssh_sub.stderr.readline() while next_line: if "debug1:" not in next_line and \ "debug2:" not in next_line and \ "debug3:" not in next_line: - - # Filter out known logs that don't start with "debug", - # but that are not useful error messages or banners. - if not next_line.startswith('Authenticated to') and \ - not next_line.startswith('Transferred: sent') and \ - not next_line.startswith('Bytes per second: sent') and \ - not next_line.startswith('OpenSSH_'): - #print(next_line, end='', file=sys.stderr) - sys.stderr.write(next_line) - + sys.stderr.write(next_line) _check_for_known_errors(next_line, delete_cert, log_list) elif print_ssh_logs: - #print(next_line, end='', file=sys.stderr) - sys.stderr.write(next_line) - - log_list.append(next_line) + # with this approach logs don't get printed very gracefully after connection was + # established in linux. Save logs to print them once connection closes. + if platform.system() == 'Windows' or not connection_established: + sys.stderr.write(next_line) + else: + log_list.append(next_line) + # Credentials are deleted once we verify from the logs that the connection was established or + # after 2 minutes from the beginning of the connection. if "debug1: Entering interactive session." in next_line: logger.debug("SSH Connection estalished succesfully.") + connection_established = True + do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + if not connection_established and \ + time.time() - t0 > const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) next_line = ssh_sub.stderr.readline() + + for line in log_list: + sys.stderr.write(line) + ssh_sub.wait() From dea6b872abf4f5841a217f799f11e1bcabf1c264 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 19 Sep 2022 18:47:22 -0400 Subject: [PATCH 7/9] No longer redirect stderr if user expecting ssh logs. When stderr not redirected, but credentials need to be deleted, wait 2 minutes --- src/ssh/azext_ssh/custom.py | 5 +- src/ssh/azext_ssh/rdp_utils.py | 11 +- src/ssh/azext_ssh/ssh_utils.py | 122 +++++++++--------- .../azext_ssh/tests/latest/test_ssh_utils.py | 102 +++++++++++++-- 4 files changed, 162 insertions(+), 78 deletions(-) diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 9bcdbfad3bd..234ee452b96 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -44,9 +44,6 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ if '--debug' in cmd.cli_ctx.data['safe_params'] and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_args): ssh_args = ['-vvv'] if not ssh_args else ['-vvv'] + ssh_args - if '-E' in ssh_args: - raise azclierror.ArgumentUsageError("The -E SSH agrgument is not supported.") - _assert_args(resource_group_name, vm_name, ssh_ip, resource_type, cert_file, local_user) # all credentials for this command are saved in temp folder and deleted at the end of execution. @@ -193,7 +190,7 @@ def _do_ssh_op(cmd, op_info, op_call): op_info.private_key_file + ', ' if delete_keys else "", op_info.public_key_file + ', ' if delete_keys else "", op_info.cert_file if delete_cert else "") - ssh_utils.do_cleanup(delete_keys, delete_cert, op_info.cert_file, + ssh_utils.do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) raise e diff --git a/src/ssh/azext_ssh/rdp_utils.py b/src/ssh/azext_ssh/rdp_utils.py index 33f2ed14e18..bf56b900118 100644 --- a/src/ssh/azext_ssh/rdp_utils.py +++ b/src/ssh/azext_ssh/rdp_utils.py @@ -43,8 +43,8 @@ def start_rdp_connection(ssh_info, delete_keys, delete_cert): ssh_process, print_ssh_logs = start_ssh_tunnel(ssh_info) ssh_connection_t0 = time.time() ssh_success, log_list = wait_for_ssh_connection(ssh_process, print_ssh_logs) - ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, - ssh_info.public_key_file) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.delete_credentials, ssh_info.cert_file, + ssh_info.private_key_file, ssh_info.public_key_file) if ssh_success and ssh_process.poll() is None: call_rdp(local_port) @@ -56,8 +56,8 @@ def start_rdp_connection(ssh_info, delete_keys, delete_cert): telemetry.add_extension_event('ssh', ssh_connection_data) terminate_ssh(ssh_process, log_list, print_ssh_logs) - ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, - ssh_info.public_key_file) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.delete_credentials, ssh_info.cert_file, + ssh_info.private_key_file, ssh_info.public_key_file) if delete_keys: # This is only true if keys were generated, so they must be in a temp folder. temp_dir = os.path.dirname(ssh_info.cert_file) @@ -107,6 +107,9 @@ def start_ssh_tunnel(op_info): else: op_info.ssh_args = ['-v'] + op_info.ssh_args + if '-E' in op_info.ssh_args: + raise azclierror.BadRequestError("Can't use -E ssh parameter when using --rdp") + command = [ssh_utils.get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] command = command + op_info.build_args() + op_info.ssh_args diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 8449f0dc787..e79709a8221 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -25,45 +25,42 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): try: - # Initialize these so that if something fails in the try block before these - # are initialized, then the finally block won't fail. - - connection_status = None - ssh_arg_list = [] if op_info.ssh_args: ssh_arg_list = op_info.ssh_args - print_ssh_logs = False - if not set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list): - print_ssh_logs = True - else: + # Redirecting stderr: + # 1. Read SSH logs to determine if authentication was successful so credentials can be deleted + # 2. Read SSHProxy error messages to print friendly error messages for well known errors. + # On Linux when connecting to a local user on a host with a banner, output gets messed up if stderr redirected. + # If user expects logs to be printed, do not redirect logs. In some ocasions output gets messed up. + is_local_user_on_linux = (platform.system() != 'Windows' and not delete_cert) + redirect_stderr = set(['-v', '-vv', '-vvv']).isdisjoint(ssh_arg_list) and \ + (op_info.is_arc or delete_cert or op_info.delete_credentials) and \ + not is_local_user_on_linux + + if redirect_stderr: ssh_arg_list = ['-v'] + ssh_arg_list env = os.environ.copy() if op_info.is_arc(): env['SSHPROXY_RELAY_INFO'] = connectivity_utils.format_relay_info_string(op_info.relay_info) - # Get ssh client before starting the clean up process in case there is an error in getting client. command = [get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] - if not op_info.cert_file and not op_info.private_key_file: - # In this case, even if delete_credentials is true, there is nothing to clean-up. - op_info.delete_credentials = False - command = command + op_info.build_args() + ssh_arg_list connection_duration = time.time() - logger.debug("Running ssh command %s", ' '.join(command)) + logger.warning("Running ssh command %s", ' '.join(command)) try: - # In these cases, there is no reason to read the logs. Not redirect stderr to avoid complications. - if (platform.system() != 'Windows' and not delete_cert or\ - platform.system() == 'Windows' and not op_info.is_arc() and not delete_cert): - ssh_process = subprocess.Popen(command, env=env, encoding='utf-8') - else: + # pylint: disable=consider-using-with + if redirect_stderr: ssh_process = subprocess.Popen(command, stderr=subprocess.PIPE, env=env, encoding='utf-8') - _read_ssh_logs(ssh_process, print_ssh_logs, op_info, delete_cert, delete_keys) + _read_ssh_logs(ssh_process, op_info, delete_cert, delete_keys) + else: + ssh_process = subprocess.Popen(command, env=env, encoding='utf-8') + _wait_to_delete_credentials(ssh_process, op_info, delete_cert, delete_keys) except OSError as e: colorama.init() raise azclierror.BadRequestError(f"Failed to run ssh command with error: {str(e)}.", @@ -71,67 +68,74 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): connection_duration = (time.time() - connection_duration) / 60 ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': connection_duration} - if connection_status and connection_status.returncode == 0: + if ssh_process.poll() == 0: ssh_connection_data['Context.Default.AzureCLI.SSHConnectionStatus'] = "Success" telemetry.add_extension_event('ssh', ssh_connection_data) finally: # Even if something fails between the creation of the credentials and the end of the ssh connection, we - # want to make sure that all credentials are cleaned up, and that the clean up process is terminated. - do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + # want to make sure that all credentials are cleaned up. + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) -def _read_ssh_logs(ssh_sub, print_ssh_logs, op_info, delete_cert, delete_keys): +def write_ssh_config(config_info, delete_keys, delete_cert): + # if delete cert is true, then this is AAD login. + config_text = config_info.get_config_text(delete_cert) + _issue_config_cleanup_warning(delete_cert, delete_keys, config_info.is_arc(), + config_info.cert_file, config_info.relay_info_path, + config_info.ssh_client_folder) + if config_info.overwrite: + mode = 'w' + else: + mode = 'a' + with open(config_info.config_path, mode, encoding='utf-8') as f: + f.write('\n'.join(config_text)) + + +def _read_ssh_logs(ssh_sub, op_info, delete_cert, delete_keys): log_list = [] connection_established = False t0 = time.time() next_line = ssh_sub.stderr.readline() while next_line: - if "debug1:" not in next_line and \ - "debug2:" not in next_line and \ - "debug3:" not in next_line: + log_list.append(next_line) + if not next_line.startswith("debug1:") and \ + not next_line.startswith("debug2:") and \ + not next_line.startswith("debug3:") and \ + not next_line.startswith("Authenticated "): sys.stderr.write(next_line) _check_for_known_errors(next_line, delete_cert, log_list) - elif print_ssh_logs: - # with this approach logs don't get printed very gracefully after connection was - # established in linux. Save logs to print them once connection closes. - if platform.system() == 'Windows' or not connection_established: - sys.stderr.write(next_line) - else: - log_list.append(next_line) - # Credentials are deleted once we verify from the logs that the connection was established or - # after 2 minutes from the beginning of the connection. if "debug1: Entering interactive session." in next_line: - logger.debug("SSH Connection estalished succesfully.") connection_established = True - do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) - + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + if not connection_established and \ time.time() - t0 > const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: - do_cleanup(delete_keys, delete_cert, op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) next_line = ssh_sub.stderr.readline() - - for line in log_list: - sys.stderr.write(line) ssh_sub.wait() -def write_ssh_config(config_info, delete_keys, delete_cert): - # if delete cert is true, then this is AAD login. - config_text = config_info.get_config_text(delete_cert) - _issue_config_cleanup_warning(delete_cert, delete_keys, config_info.is_arc(), - config_info.cert_file, config_info.relay_info_path, - config_info.ssh_client_folder) - if config_info.overwrite: - mode = 'w' - else: - mode = 'a' - with open(config_info.config_path, mode, encoding='utf-8') as f: - f.write('\n'.join(config_text)) +def _wait_to_delete_credentials(ssh_sub, op_info, delete_cert, delete_keys): + # wait for 2 minutes. If the process isn't closed until then, delete credentials. + if delete_cert or op_info.delete_credentials: + t0 = time.time() + while (time.time() - t0) < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS: + if ssh_sub.poll() is not None: + break + time.sleep(1) + + do_cleanup(delete_keys, delete_cert, op_info.delete_credentials, + op_info.cert_file, op_info.private_key_file, op_info.public_key_file) + + ssh_sub.wait() def create_ssh_keyfile(private_key_file, ssh_client_folder=None): @@ -296,12 +300,12 @@ def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): return ssh_path -def do_cleanup(delete_keys, delete_cert, cert_file, private_key, public_key): - if delete_keys and private_key: +def do_cleanup(delete_keys, delete_cert, delete_credentials, cert_file, private_key, public_key): + if (delete_keys or delete_credentials) and private_key: file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) if delete_keys and public_key: file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) - if delete_cert and cert_file: + if (delete_cert or delete_credentials) and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) if delete_keys: # This is only true if keys were generated, so they must be in a temp folder. diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index b7d88d9af78..d934c0b26c4 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -12,23 +12,29 @@ from azext_ssh import ssh_utils from azext_ssh import ssh_info + class SSHUtilsTests(unittest.TestCase): @mock.patch.object(ssh_utils, 'do_cleanup') @mock.patch.object(ssh_utils, '_read_ssh_logs') @mock.patch.object(ssh_utils, 'get_ssh_client_path') @mock.patch('subprocess.Popen') @mock.patch('os.environ.copy') - def test_start_ssh_connection_compute(self, mock_copy_env, mock_call, mock_path, mock_read, mock_cleanup): + @mock.patch('platform.system') + def test_start_ssh_connection_compute_aad_windows(self, mock_system, mock_copy_env, mock_call, mock_path, mock_read, mock_cleanup): - op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) + op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute/virtualMachines", None, None, False) op_info.public_key_file = "pub" op_info.private_key_file = "priv" op_info.cert_file = "cert" op_info.ssh_client_folder = "client" + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + + mock_system.return_value = 'Windows' mock_call.return_value = 0 mock_path.return_value = 'ssh' - mock_call.return_value = 'ssh_process' + mock_call.return_value = ssh_process mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} expected_command = ['ssh', 'user@ip', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', '-v', 'arg1', 'arg2', 'arg3'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} @@ -37,9 +43,41 @@ def test_start_ssh_connection_compute(self, mock_copy_env, mock_call, mock_path, mock_path.assert_called_once_with('ssh', 'client') mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') - mock_read.assert_called_once_with('ssh_process', False, op_info, True, True) - mock_cleanup.assert_called_once_with(True, True, 'cert', 'priv', 'pub') + mock_read.assert_called_once_with(ssh_process, op_info, True, True) + mock_cleanup.assert_called_once_with(True, True, False, 'cert', 'priv', 'pub') + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_wait_to_delete_credentials') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') + @mock.patch('subprocess.Popen') + @mock.patch('os.environ.copy') + @mock.patch('platform.system') + def test_start_ssh_connection_compute_local_linux(self, mock_system, mock_copy_env, mock_call, mock_path, mock_wait, mock_cleanup): + + op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + + mock_system.return_value = 'Linux' + mock_call.return_value = 0 + mock_path.return_value = 'ssh' + mock_call.return_value = ssh_process + mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + expected_command = ['ssh', 'user@ip', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3'] + expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + + ssh_utils.start_ssh_connection(op_info, False, False) + + mock_path.assert_called_once_with('ssh', 'client') + mock_call.assert_called_once_with(expected_command, env=expected_env, encoding='utf-8') + mock_wait.assert_called_once_with(ssh_process, op_info, False, False) + mock_cleanup.assert_called_once_with(False, False, False, 'cert', 'priv', 'pub') + @mock.patch.object(ssh_utils, 'do_cleanup') @mock.patch.object(ssh_utils, '_read_ssh_logs') @@ -47,7 +85,8 @@ def test_start_ssh_connection_compute(self, mock_copy_env, mock_call, mock_path, @mock.patch('os.environ.copy') @mock.patch('subprocess.Popen') @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') - def test_start_ssh_connection_arc(self, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_read, mock_cleanup): + @mock.patch('platform.system') + def test_start_ssh_connection_arc_aad_windows(self, mock_platform, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_read, mock_cleanup): op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) op_info.public_key_file = "pub" @@ -56,22 +95,63 @@ def test_start_ssh_connection_arc(self, mock_relay_str, mock_call, mock_copy_env op_info.ssh_client_folder = "client" op_info.proxy_path = "proxy" op_info.relay_info = "relay" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 - mock_call.return_value = 'ssh_process' + mock_platform.return_value = 'Windows' + mock_call.return_value = ssh_process mock_relay_str.return_value = 'relay_string' mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} mock_path.return_value = 'ssh' expected_command = ['ssh', 'user@vm', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-v', 'arg1'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} - ssh_utils.start_ssh_connection(op_info, False, False) + ssh_utils.start_ssh_connection(op_info, True, True) mock_relay_str.assert_called_once_with('relay') mock_path.assert_called_once_with('ssh', 'client') mock_call.assert_called_once_with(expected_command, stderr=mock.ANY, env=expected_env, encoding='utf-8') - mock_cleanup.assert_called_once_with(False, False, 'cert', 'priv', 'pub') - mock_read.assert_called_once_with('ssh_process', False, op_info, False, False) - + mock_cleanup.assert_called_once_with(True, True, False, 'cert', 'priv', 'pub') + mock_read.assert_called_once_with(ssh_process, op_info, True, True) + + + @mock.patch.object(ssh_utils, 'do_cleanup') + @mock.patch.object(ssh_utils, '_wait_to_delete_credentials') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') + @mock.patch('os.environ.copy') + @mock.patch('subprocess.Popen') + @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') + @mock.patch('platform.system') + def test_start_ssh_connection_arc_local_linux(self, mock_platform, mock_relay_str, mock_call, mock_copy_env, mock_path, mock_wait, mock_cleanup): + + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + op_info.proxy_path = "proxy" + op_info.relay_info = "relay" + + ssh_process = mock.Mock() + ssh_process.poll.return_value = 0 + + mock_platform.return_value = 'Linux' + mock_call.return_value = ssh_process + mock_relay_str.return_value = 'relay_string' + mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + mock_path.return_value = 'ssh' + expected_command = ['ssh', 'user@vm', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', 'arg1'] + expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} + + ssh_utils.start_ssh_connection(op_info, False, False) + + mock_relay_str.assert_called_once_with('relay') + mock_path.assert_called_once_with('ssh', 'client') + mock_call.assert_called_once_with(expected_command, env=expected_env, encoding='utf-8') + mock_cleanup.assert_called_once_with(False, False, False, 'cert', 'priv', 'pub') + mock_wait.assert_called_once_with(ssh_process, op_info, False, False) + @mock.patch.object(ssh_utils, '_issue_config_cleanup_warning') @mock.patch('os.path.abspath') From 73decda345b379e7a05d20b6c1c13b49b1ace0a2 Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Mon, 19 Sep 2022 18:56:18 -0400 Subject: [PATCH 8/9] remove warning --- src/ssh/azext_ssh/ssh_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index e79709a8221..2e510b2ea4b 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -51,7 +51,7 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): command = command + op_info.build_args() + ssh_arg_list connection_duration = time.time() - logger.warning("Running ssh command %s", ' '.join(command)) + logger.debug("Running ssh command %s", ' '.join(command)) try: # pylint: disable=consider-using-with From a93c6f86c4b65830efed15158e84e8f9dc21622c Mon Sep 17 00:00:00 2001 From: Vivian Thiebaut Date: Thu, 20 Oct 2022 18:49:33 -0400 Subject: [PATCH 9/9] Resolve conflicts, fix unit tests --- src/ssh/HISTORY.md | 2 +- src/ssh/azext_ssh/ssh_utils.py | 2 +- src/ssh/azext_ssh/tests/latest/test_ssh_utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 94dc6492a53..6e51f477016 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -2,7 +2,7 @@ Release History =============== 1.1.3 ----- -* Fix bug where extension prints ssh banners after the connection is closed. +* [bug fix] SSH Banners are printed before authentication. 1.1.2 ----- diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index ffda24fb7ba..b889ee7cbc6 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -308,7 +308,7 @@ def do_cleanup(delete_keys, delete_cert, delete_credentials, cert_file, private_ file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) if (delete_cert or delete_credentials) and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) - if delete_keys: + if delete_keys and cert_file: # This is only true if keys were generated, so they must be in a temp folder. temp_dir = os.path.dirname(cert_file) file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 7bdd8e5a9d7..27c9f45e54d 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -36,7 +36,7 @@ def test_start_ssh_connection_compute_aad_windows(self, mock_system, mock_copy_e mock_path.return_value = 'ssh' mock_call.return_value = ssh_process mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} - expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3', '-E', 'log', '-v'] + expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', '-v', 'arg1', 'arg2', 'arg3'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} ssh_utils.start_ssh_connection(op_info, True, True) @@ -68,7 +68,7 @@ def test_start_ssh_connection_compute_local_linux(self, mock_system, mock_copy_e mock_path.return_value = 'ssh' mock_call.return_value = ssh_process mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} - expected_command = ['ssh', 'user@ip', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3'] + expected_command = ['ssh', 'ip', '-l', 'user', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-p', 'port', 'arg1', 'arg2', 'arg3'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3'} ssh_utils.start_ssh_connection(op_info, False, False) @@ -104,7 +104,7 @@ def test_start_ssh_connection_arc_aad_windows(self, mock_platform, mock_relay_st mock_relay_str.return_value = 'relay_string' mock_copy_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} mock_path.return_value = 'ssh' - expected_command = ['ssh', 'user@vm', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-v', 'arg1'] + expected_command = ['ssh', 'vm', '-l', 'user', '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', '-v', 'arg1'] expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} ssh_utils.start_ssh_connection(op_info, True, True)