diff --git a/src/ssh/azext_ssh/_help.py b/src/ssh/azext_ssh/_help.py index 78b3999a4b8..e48def87252 100644 --- a/src/ssh/azext_ssh/_help.py +++ b/src/ssh/azext_ssh/_help.py @@ -51,6 +51,10 @@ - name: Give a SSH Client Folder to use the ssh executables in that folder, like ssh-keygen.exe and ssh.exe. If not provided, the extension attempts to use pre-installed OpenSSH client (on Windows, extension looks for pre-installed executables under C:\\Windows\\System32\\OpenSSH). text: | az ssh vm --resource-group myResourceGroup --name myVM --ssh-client-folder "C:\\Program Files\\OpenSSH" + + - name: Open RDP connection over SSH. Useful for connecting via RDP to Arc Servers with no public IP address. Currently only supported for Windows clients. + text: | + az ssh vm --resource-group myResourceGroup --name myVM --local-user username --rdp """ helps['ssh config'] = """ @@ -144,4 +148,8 @@ - name: Give a SSH Client Folder to use the ssh executables in that folder, like ssh-keygen.exe and ssh.exe. If not provided, the extension attempts to use pre-installed OpenSSH client (on Windows, extension looks for pre-installed executables under C:\\Windows\\System32\\OpenSSH). text: | az ssh arc --resource-group myResourceGroup --name myMachine --ssh-client-folder "C:\\Program Files\\OpenSSH" + + - name: Open RDP connection over SSH. Useful for connecting via RDP to Arc Servers with no public IP address. Currently only supported for Windows clients. + text: | + az ssh arc --resource-group myResourceGroup --name myVM --local-user username --rdp """ diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 9e9bf17d3ff..59718ccd9ee 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -32,6 +32,8 @@ def load_arguments(self, _): c.argument('ssh_proxy_folder', options_list=['--ssh-proxy-folder'], help=('Path to the folder where the ssh proxy should be saved. ' 'Default to .clientsshproxy folder in user\'s home directory if not provided.')) + c.argument('winrdp', options_list=['--winrdp', '--rdp'], help=('Start RDP connection over SSH.'), + action='store_true') c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') with self.argument_context('ssh config') as c: @@ -87,4 +89,6 @@ def load_arguments(self, _): c.argument('ssh_proxy_folder', options_list=['--ssh-proxy-folder'], help=('Path to the folder where the ssh proxy should be saved. ' 'Default to .clientsshproxy folder in user\'s home directory if not provided.')) + c.argument('winrdp', options_list=['--winrdp', '--rdp'], help=('Start RDP connection over SSH.'), + action='store_true') c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') diff --git a/src/ssh/azext_ssh/_process_helper.py b/src/ssh/azext_ssh/_process_helper.py new file mode 100644 index 00000000000..e8741d05216 --- /dev/null +++ b/src/ssh/azext_ssh/_process_helper.py @@ -0,0 +1,112 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=too-few-public-methods +# pylint: disable=consider-using-with + +import subprocess +from ctypes import WinDLL, c_int, c_size_t, Structure, WinError, sizeof, pointer +from ctypes.wintypes import BOOL, DWORD, HANDLE, LPVOID, LPCWSTR, LPDWORD +from knack.log import get_logger + +logger = get_logger(__name__) + + +def _errcheck(is_error_result=(lambda result: not result)): + def impl(result, func, args): + # pylint: disable=unused-argument + if is_error_result(result): + raise WinError() + + return result + + return impl + + +# Win32 CreateJobObject +kernel32 = WinDLL("kernel32") +kernel32.CreateJobObjectW.errcheck = _errcheck(lambda result: result == 0) +kernel32.CreateJobObjectW.argtypes = (LPVOID, LPCWSTR) +kernel32.CreateJobObjectW.restype = HANDLE + + +# Win32 OpenProcess +PROCESS_TERMINATE = 0x0001 +PROCESS_SET_QUOTA = 0x0100 +PROCESS_SYNCHRONIZE = 0x00100000 +kernel32.OpenProcess.errcheck = _errcheck(lambda result: result == 0) +kernel32.OpenProcess.restype = HANDLE +kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + +# Win32 WaitForSingleObject +INFINITE = 0xFFFFFFFF +# kernel32.WaitForSingleObject.errcheck = _errcheck() +kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) +kernel32.WaitForSingleObject.restype = DWORD + +# Win32 AssignProcessToJobObject +kernel32.AssignProcessToJobObject.errcheck = _errcheck() +kernel32.AssignProcessToJobObject.argtypes = (HANDLE, HANDLE) +kernel32.AssignProcessToJobObject.restype = BOOL + +# Win32 QueryInformationJobObject +JOBOBJECTCLASS = c_int +JobObjectBasicProcessIdList = JOBOBJECTCLASS(3) + + +class JOBOBJECT_BASIC_PROCESS_ID_LIST(Structure): + _fields_ = [('NumberOfAssignedProcess', DWORD), + ('NumberOfProcessIdsInList', DWORD), + ('ProcessIdList', c_size_t * 1)] + + +kernel32.QueryInformationJobObject.errcheck = _errcheck() +kernel32.QueryInformationJobObject.restype = BOOL +kernel32.QueryInformationJobObject.argtypes = (HANDLE, JOBOBJECTCLASS, LPVOID, DWORD, LPDWORD) + + +def launch_and_wait(command): + """Windows Only: Runs and waits for the command to exit. It creates a new process and + associates it with a job object. It then waits for all the job object child processes + to exit. + """ + try: + job = kernel32.CreateJobObjectW(None, None) + process = subprocess.Popen(command) + + # Terminate and set quota are required to join process to job + process_handle = kernel32.OpenProcess( + PROCESS_TERMINATE | PROCESS_SET_QUOTA, + False, + process.pid, + ) + kernel32.AssignProcessToJobObject(job, process_handle) + + job_info = JOBOBJECT_BASIC_PROCESS_ID_LIST() + job_info_size = DWORD(sizeof(job_info)) + + while True: + kernel32.QueryInformationJobObject( + job, + JobObjectBasicProcessIdList, + pointer(job_info), + job_info_size, + pointer(job_info_size)) + + # Wait for the first running child under the job object + if job_info.NumberOfProcessIdsInList > 0: + logger.debug("Waiting for process %d", job_info.ProcessIdList[0]) + # Synchronize access is required to wait on handle + child_handle = kernel32.OpenProcess( + PROCESS_SYNCHRONIZE, + False, + job_info.ProcessIdList[0], + ) + kernel32.WaitForSingleObject(child_handle, INFINITE) + else: + break + + except OSError as e: + logger.error("Could not run '%s' command. Exception: %s", command, str(e)) diff --git a/src/ssh/azext_ssh/connectivity_utils.py b/src/ssh/azext_ssh/connectivity_utils.py index 0a625f220ed..fd38dab6281 100644 --- a/src/ssh/azext_ssh/connectivity_utils.py +++ b/src/ssh/azext_ssh/connectivity_utils.py @@ -12,9 +12,8 @@ from glob import glob import colorama -from colorama import Fore -from colorama import Style +from azure.cli.core.style import Style, print_styled_text from azure.core.exceptions import ResourceNotFoundError from azure.cli.core import telemetry from azure.cli.core import azclierror @@ -69,7 +68,8 @@ def _create_default_endpoint(cmd, resource_group, vm_name, client): colorama.init() raise azclierror.UnauthorizedError(f"Unable to create Default Endpoint for {vm_name} in {resource_group}." f"\nError: {str(e)}", - Fore.YELLOW + "Contact Owner/Contributor of the resource." + Style.RESET_ALL) + colorama.Fore.YELLOW + "Contact Owner/Contributor of the resource." + + colorama.Style.RESET_ALL) # Downloads client side proxy to connect to Arc Connectivity Platform @@ -109,8 +109,7 @@ def get_client_side_proxy(arc_proxy_folder): # write executable in the install location file_utils.write_to_file(install_location, 'wb', response_content, "Failed to create client proxy file. ") os.chmod(install_location, os.stat(install_location).st_mode | stat.S_IXUSR) - colorama.init() - print(Fore.GREEN + f"SSH Client Proxy saved to {install_location}" + Style.RESET_ALL) + print_styled_text((Style.SUCCESS, f"SSH Client Proxy saved to {install_location}")) return install_location diff --git a/src/ssh/azext_ssh/constants.py b/src/ssh/azext_ssh/constants.py index eaf98715102..4917411aeb6 100644 --- a/src/ssh/azext_ssh/constants.py +++ b/src/ssh/azext_ssh/constants.py @@ -17,3 +17,4 @@ "--ssh-client-folder to provide OpenSSH folder path." + Style.RESET_ALL) RECOMMENDATION_RESOURCE_NOT_FOUND = (Fore.YELLOW + "Please ensure the active subscription is set properly " "and resource exists." + Style.RESET_ALL) +RDP_TERMINATE_SSH_WAIT_TIME_IN_SECONDS = 30 diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 58c2033f0f3..e396fcefcaa 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -12,15 +12,15 @@ import oschmod import colorama -from colorama import Fore -from colorama import Style from knack import log from azure.cli.core import azclierror from azure.cli.core import telemetry from azure.core.exceptions import ResourceNotFoundError, HttpResponseError +from azure.cli.core.style import Style, print_styled_text from . import ip_utils +from . import rdp_utils from . import rsa_parser from . import ssh_utils from . import connectivity_utils @@ -33,7 +33,8 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None, - ssh_client_folder=None, delete_credentials=False, resource_type=None, ssh_proxy_folder=None, ssh_args=None): + ssh_client_folder=None, delete_credentials=False, resource_type=None, ssh_proxy_folder=None, + winrdp=False, ssh_args=None): # delete_credentials can only be used by Azure Portal to provide one-click experience on CloudShell. if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0": @@ -49,10 +50,17 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_ credentials_folder = None op_call = ssh_utils.start_ssh_connection + if winrdp: + if platform.system() != 'Windows': + raise azclierror.BadRequestError("RDP connection is not supported for this platform. " + "Supported platforms: Windows") + logger.warning("RDP feature is in preview.") + op_call = rdp_utils.start_rdp_connection + ssh_session = ssh_info.SSHSession(resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, local_user, cert_file, port, ssh_client_folder, ssh_args, delete_credentials, resource_type, - ssh_proxy_folder, credentials_folder) + ssh_proxy_folder, credentials_folder, winrdp) ssh_session.resource_type = _decide_resource_type(cmd, ssh_session) _do_ssh_op(cmd, ssh_session, op_call) @@ -120,24 +128,22 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None): if keys_folder: logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). " "Please delete once this certificate is no longer being used.", keys_folder) - - colorama.init() # pylint: disable=broad-except try: cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1] - print(Fore.GREEN + f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time." - + Style.RESET_ALL) + print_styled_text((Style.SUCCESS, + f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time.")) except Exception as e: logger.warning("Couldn't determine certificate validity. Error: %s", str(e)) - print(Fore.GREEN + f"Generated SSH certificate {cert_file}." + Style.RESET_ALL) + print_styled_text((Style.SUCCESS, f"Generated SSH certificate {cert_file}.")) def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, private_key_file=None, local_user=None, cert_file=None, port=None, ssh_client_folder=None, delete_credentials=False, - ssh_proxy_folder=None, ssh_args=None): + ssh_proxy_folder=None, winrdp=False, ssh_args=None): ssh_vm(cmd, resource_group_name, vm_name, None, public_key_file, private_key_file, False, local_user, cert_file, - port, ssh_client_folder, delete_credentials, "Microsoft.HybridCompute", ssh_proxy_folder, ssh_args) + port, ssh_client_folder, delete_credentials, "Microsoft.HybridCompute", ssh_proxy_folder, winrdp, ssh_args) def _do_ssh_op(cmd, op_info, op_call): @@ -389,7 +395,8 @@ def _decide_resource_type(cmd, op_info): colorama.init() raise azclierror.BadRequestError(f"{op_info.resource_group_name} has Azure VM and Arc Server with the " f"same name: {op_info.vm_name}.", - Fore.YELLOW + "Please provide a --resource-type." + Style.RESET_ALL) + colorama.Fore.YELLOW + "Please provide a --resource-type." + + colorama.Style.RESET_ALL) if not is_azure_vm and not is_arc_server: colorama.init() if isinstance(arc_error, ResourceNotFoundError) and isinstance(vm_error, ResourceNotFoundError): @@ -416,7 +423,8 @@ def _decide_resource_type(cmd, op_info): colorama.init() raise azclierror.RequiredArgumentMissingError("SSH Login using AAD credentials is not currently supported " "for Windows.", - Fore.YELLOW + "Please provide --local-user." + Style.RESET_ALL) + colorama.Fore.YELLOW + "Please provide --local-user." + + colorama.Style.RESET_ALL) target_resource_type = "Microsoft.Compute" if is_arc_server: diff --git a/src/ssh/azext_ssh/rdp_utils.py b/src/ssh/azext_ssh/rdp_utils.py new file mode 100644 index 00000000000..63a6ff3c460 --- /dev/null +++ b/src/ssh/azext_ssh/rdp_utils.py @@ -0,0 +1,197 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import platform +import subprocess +import time +import psutil + +from knack import log +from azure.cli.core import azclierror +from azure.cli.core import telemetry +from azure.cli.core.style import Style, print_styled_text + +from . import ssh_utils +from . import connectivity_utils +from . import constants as const +from . import file_utils + +logger = log.get_logger(__name__) + + +def start_rdp_connection(ssh_info, delete_keys, delete_cert): + try: + ssh_process = None + log_list = [] + print_ssh_logs = False + ssh_success = False + + resource_port = 3389 + local_port = _get_open_port() + + while not is_local_port_open(local_port): + local_port = _get_open_port() + + if ssh_info.ssh_args is None: + ssh_info.ssh_args = ['-L', f"{local_port}:localhost:{resource_port}", "-N"] + else: + ssh_info.ssh_args = ['-L', f"{local_port}:localhost:{resource_port}", "-N"] + ssh_info.ssh_args + + ssh_process, print_ssh_logs = start_ssh_tunnel(ssh_info) + ssh_connection_t0 = time.time() + ssh_success, log_list = wait_for_ssh_connection(ssh_process, print_ssh_logs) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, + ssh_info.public_key_file) + if ssh_success and ssh_process.poll() is None: + call_rdp(local_port) + + finally: + if ssh_success: + ssh_connection_data = {'Context.Default.AzureCLI.SSHConnectionDurationInMinutes': + (time.time() - ssh_connection_t0) / 60} + ssh_connection_data['Context.Default.AzureCLI.SSHConnectionStatus'] = "Success" + telemetry.add_extension_event('ssh', ssh_connection_data) + + terminate_ssh(ssh_process, log_list, print_ssh_logs) + ssh_utils.do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, + ssh_info.public_key_file) + if delete_keys: + # This is only true if keys were generated, so they must be in a temp folder. + temp_dir = os.path.dirname(ssh_info.cert_file) + file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) + + +def call_rdp(local_port): + from . import _process_helper + if platform.system() == 'Windows': + print_styled_text((Style.SUCCESS, "Launching Remote Desktop Connection")) + print_styled_text((Style.IMPORTANT, "To close this session, close the Remote Desktop Connection window.")) + command = [_get_rdp_path(), f"/v:localhost:{local_port}"] + _process_helper.launch_and_wait(command) + + +def is_local_port_open(local_port): + import socket + from contextlib import closing + is_port_open = False + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: + if sock.connect_ex(('', local_port)) == 0: + logger.info('Port %s is NOT open', local_port) + else: + logger.info('Port %s is open', local_port) + is_port_open = True + return is_port_open + + +def _get_open_port(): + import socket + from contextlib import closing + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def start_ssh_tunnel(op_info): + # pylint: disable=consider-using-with + env = os.environ.copy() + if op_info.is_arc(): + env['SSHPROXY_RELAY_INFO'] = connectivity_utils.format_relay_info_string(op_info.relay_info) + + print_ssh_logs = False + if not set(['-v', '-vv', '-vvv']).isdisjoint(op_info.ssh_args): + print_ssh_logs = True + else: + op_info.ssh_args = ['-v'] + op_info.ssh_args + + if '-E' in op_info.ssh_args: + raise azclierror.BadRequestError("Can't use -E ssh parameter when using --rdp") + + command = [ssh_utils.get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] + command = command + op_info.build_args() + op_info.ssh_args + + logger.debug("Running ssh command %s", ' '.join(command)) + ssh_sub = subprocess.Popen(command, shell=True, stderr=subprocess.PIPE, env=env, encoding='utf-8') + return ssh_sub, print_ssh_logs + + +def wait_for_ssh_connection(ssh_sub, print_ssh_logs): + log_list = [] + ssh_sucess = False + while True: + next_line = ssh_sub.stderr.readline() + if print_ssh_logs: + print(next_line, end='') + else: + log_list.append(next_line) + if "debug1: Entering interactive session." in next_line: + logger.debug("SSH Connection estalished succesfully.") + ssh_sucess = True + break + if ssh_sub.poll() is not None: + logger.debug("SSH Connection failed.") + ssh_sucess = False + break + return ssh_sucess, log_list + + +def terminate_ssh(ssh_process, log_list, print_ssh_logs): + if ssh_process: + terminated = False + if ssh_process.poll() is None: + kill_process(ssh_process.pid) + t0 = time.time() + while ssh_process.poll() is None and (time.time() - t0) < const.RDP_TERMINATE_SSH_WAIT_TIME_IN_SECONDS: + time.sleep(1) + terminated = True + print_error_messages_from_log(log_list, print_ssh_logs, ssh_process, terminated) + + +def print_error_messages_from_log(log_list, print_ssh_logs, ssh_process, terminated): + # Read the remaining log messages since the connection was established. + next_line = ssh_process.stderr.readline() + while next_line: + if print_ssh_logs: + print(next_line, end='') + else: + log_list.append(next_line) + next_line = ssh_process.stderr.readline() + + # If ssh process was not forced to terminate, print potential error messages. + if ssh_process.returncode != 0 and not print_ssh_logs and not terminated: + for line in log_list: + if "debug1:" not in line and line != '': + print(str(line), end='') + + +def _get_rdp_path(rdp_command="mstsc"): + rdp_path = rdp_command + if platform.system() == 'Windows': + arch_data = platform.architecture() + sys_path = 'System32' + system_root = os.environ['SystemRoot'] + system32_path = os.path.join(system_root, sys_path) + rdp_path = os.path.join(system32_path, (rdp_command + ".exe")) + logger.debug("Platform architecture: %s", str(arch_data)) + logger.debug("System Root: %s", system_root) + logger.debug("Attempting to run rdp from path %s", rdp_path) + + if not os.path.isfile(rdp_path): + raise azclierror.BadRequestError("Could not find " + rdp_command + ".exe. Is the rdp client installed?") + else: + raise azclierror.BadRequestError("Platform is not supported for this command. Supported platforms: Windows") + + return rdp_path + + +def kill_process(pid): + try: + process = psutil.Process(pid) + for proc in process.children(recursive=True): + proc.kill() + process.kill() + except psutil.NoSuchProcess as e: + logger.warning("Kill process failed. Process no longer exists: %s", str(e)) diff --git a/src/ssh/azext_ssh/ssh_info.py b/src/ssh/azext_ssh/ssh_info.py index 5e99459e924..37672abc05e 100644 --- a/src/ssh/azext_ssh/ssh_info.py +++ b/src/ssh/azext_ssh/ssh_info.py @@ -6,9 +6,7 @@ import datetime import oschmod -import colorama -from colorama import Fore -from colorama import Style +from azure.cli.core.style import Style, print_styled_text from azure.cli.core import azclierror from knack import log @@ -22,7 +20,7 @@ class SSHSession(): # pylint: disable=too-many-instance-attributes def __init__(self, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, local_user, cert_file, port, ssh_client_folder, ssh_args, - delete_credentials, resource_type, ssh_proxy_folder, credentials_folder): + delete_credentials, resource_type, ssh_proxy_folder, credentials_folder, winrdp): self.resource_group_name = resource_group_name self.vm_name = vm_name self.ip = ssh_ip @@ -32,6 +30,7 @@ def __init__(self, resource_group_name, vm_name, ssh_ip, public_key_file, privat self.ssh_args = ssh_args self.delete_credentials = delete_credentials self.resource_type = resource_type + self.winrdp = winrdp self.proxy_path = None self.relay_info = None self.public_key_file = os.path.abspath(public_key_file) if public_key_file else None @@ -189,9 +188,8 @@ def _create_relay_info_file(self): try: expiration = datetime.datetime.fromtimestamp(self.relay_info.expires_on) expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") - colorama.init() - print(Fore.GREEN + f"Generated relay information {relay_info_path} is valid until {expiration} " - "in local time." + Style.RESET_ALL) + print_styled_text((Style.SUCCESS, f"Generated relay information {relay_info_path} is valid until " + f"{expiration} in local time.")) except Exception as e: logger.warning("Couldn't determine relay information expiration. Error: %s", str(e)) diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index 70eb16de634..1986a604431 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -11,12 +11,10 @@ import re import colorama -from colorama import Fore -from colorama import Style - from knack import log from azure.cli.core import azclierror from azure.cli.core import telemetry +from azure.cli.core.style import Style, print_styled_text from . import file_utils from . import connectivity_utils @@ -42,7 +40,7 @@ def start_ssh_connection(op_info, delete_keys, delete_cert): env['SSHPROXY_RELAY_INFO'] = connectivity_utils.format_relay_info_string(op_info.relay_info) # Get ssh client before starting the clean up process in case there is an error in getting client. - command = [_get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] + command = [get_ssh_client_path('ssh', op_info.ssh_client_folder), op_info.get_host()] if not op_info.cert_file and not op_info.private_key_file: # In this case, even if delete_credentials is true, there is nothing to clean-up. @@ -97,7 +95,7 @@ def write_ssh_config(config_info, delete_keys, delete_cert): def create_ssh_keyfile(private_key_file, ssh_client_folder=None): - sshkeygen_path = _get_ssh_client_path("ssh-keygen", ssh_client_folder) + sshkeygen_path = get_ssh_client_path("ssh-keygen", ssh_client_folder) command = [sshkeygen_path, "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] logger.debug("Running ssh-keygen command %s", ' '.join(command)) try: @@ -109,7 +107,7 @@ def create_ssh_keyfile(private_key_file, ssh_client_folder=None): def get_ssh_cert_info(cert_file, ssh_client_folder=None): - sshkeygen_path = _get_ssh_client_path("ssh-keygen", ssh_client_folder) + sshkeygen_path = get_ssh_client_path("ssh-keygen", ssh_client_folder) command = [sshkeygen_path, "-L", "-f", cert_file] logger.debug("Running ssh-keygen command %s", ' '.join(command)) try: @@ -211,7 +209,7 @@ def _print_error_messages_from_ssh_log(log_file, connection_status, delete_cert) ssh_log.close() -def _get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): +def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): if ssh_client_folder: ssh_path = os.path.join(ssh_client_folder, ssh_command) if platform.system() == 'Windows': @@ -258,9 +256,9 @@ def _get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): if not os.path.isfile(ssh_path): raise azclierror.UnclassifiedUserFault( "Could not find " + ssh_command + ".exe on path " + ssh_path + ". ", - Fore.YELLOW + "Make sure OpenSSH is installed correctly: " + colorama.Fore.YELLOW + "Make sure OpenSSH is installed correctly: " "https://docs.microsoft.com/en-us/windows-server/administration/openssh/openssh_install_firstuse . " - "Or use --ssh-client-folder to provide folder path with ssh executables. " + Style.RESET_ALL) + "Or use --ssh-client-folder to provide folder path with ssh executables. " + colorama.Style.RESET_ALL) return ssh_path @@ -359,13 +357,12 @@ def _terminate_cleanup(delete_keys, delete_cert, delete_credentials, cleanup_pro def _issue_config_cleanup_warning(delete_cert, delete_keys, is_arc, cert_file, relay_info_path, ssh_client_folder): if delete_cert: - colorama.init() # pylint: disable=broad-except try: expiration = get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1] expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") - print(Fore.GREEN + f"Generated SSH certificate {cert_file} is valid until {expiration} in local time." - + Style.RESET_ALL) + print_styled_text((Style.SUCCESS, + f"Generated SSH certificate {cert_file} is valid until {expiration} in local time.")) except Exception as e: logger.warning("Couldn't determine certificate expiration. Error: %s", str(e)) diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 51212fac195..69e1b6f3c2c 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -6,6 +6,7 @@ import unittest from unittest import mock from azext_ssh import custom +from azext_ssh import rdp_utils from azext_ssh import ssh_utils from azext_ssh import ssh_info @@ -26,12 +27,32 @@ def test_ssh_vm(self, mock_type, mock_info, mock_assert, mock_do_op): cmd.cli_ctx.data = {'safe_params': []} - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", ['-vvv']) + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", False, ['-vvv']) - mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None) + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, False) mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username") mock_type.assert_called_once_with(cmd, ssh_info) mock_do_op.assert_called_once_with(cmd, ssh_info, ssh_utils.start_ssh_connection) + + @mock.patch('azext_ssh.custom._do_ssh_op') + @mock.patch('azext_ssh.custom._assert_args') + @mock.patch('azext_ssh.ssh_info.SSHSession') + @mock.patch('azext_ssh.custom._decide_resource_type') + @mock.patch('platform.system') + def test_ssh_vm_rdp(self, mock_sys, mock_type, mock_info, mock_assert, mock_do_op): + cmd = mock.Mock() + ssh_info = mock.Mock() + mock_info.return_value = ssh_info + mock_sys.return_value = 'Windows' + + cmd.cli_ctx.data = {'safe_params': []} + + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", True, ['-vvv']) + + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, True) + mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username") + mock_type.assert_called_once_with(cmd, ssh_info) + mock_do_op.assert_called_once_with(cmd, ssh_info, rdp_utils.start_rdp_connection) @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.custom._assert_args') @@ -44,9 +65,9 @@ def test_ssh_vm_debug(self, mock_type, mock_info, mock_assert, mock_do_op): cmd.cli_ctx.data = {'safe_params': ['--debug']} - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", []) + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", False, []) - mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None) + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, False) mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username") mock_type.assert_called_once_with(cmd, ssh_info) mock_do_op.assert_called_once_with(cmd, ssh_info, ssh_utils.start_ssh_connection) @@ -63,9 +84,9 @@ def test_ssh_vm_delete_credentials_cloudshell(self, mock_info, mock_assert, mock cmd.cli_ctx.data = {'safe_params': []} mock_info.return_value = ssh_info - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", True, "type", "proxy", []) + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", True, "type", "proxy", False, []) - mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", [], True, "type", "proxy", None) + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", [], True, "type", "proxy", None, False) mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username") mock_type.assert_called_once_with(cmd, ssh_info) mock_op.assert_called_once_with(cmd, ssh_info, ssh_utils.start_ssh_connection) @@ -75,7 +96,7 @@ def test_delete_credentials_not_cloudshell(self, mock_getenv): mock_getenv.return_value = None cmd = mock.Mock() self.assertRaises( - azclierror.ArgumentUsageError, custom.ssh_vm, cmd, 'rg', 'vm', 'ip', 'pub', 'priv', False, 'user', 'cert', 'port', 'client', True, 'type', 'proxy', []) + azclierror.ArgumentUsageError, custom.ssh_vm, cmd, 'rg', 'vm', 'ip', 'pub', 'priv', False, 'user', 'cert', 'port', 'client', True, 'type', 'proxy', False, []) @mock.patch('azext_ssh.custom._assert_args') @mock.patch('azext_ssh.custom._do_ssh_op') @@ -141,9 +162,9 @@ def test_ssh_config_credentials_folder_and_key(self): @mock.patch('azext_ssh.custom.ssh_vm') def test_ssh_arc(self, mock_vm): cmd = mock.Mock() - custom.ssh_arc(cmd, "rg", "vm", "pub", "priv", "user", "cert", "port", "client", False, "proxy", []) + custom.ssh_arc(cmd, "rg", "vm", "pub", "priv", "user", "cert", "port", "client", False, "proxy", False, []) - mock_vm.assert_called_once_with(cmd, "rg", "vm", None, "pub", "priv", False, "user", "cert", "port", "client", False, "Microsoft.HybridCompute", "proxy", []) + mock_vm.assert_called_once_with(cmd, "rg", "vm", None, "pub", "priv", False, "user", "cert", "port", "client", False, "Microsoft.HybridCompute", "proxy", False, []) def test_ssh_cert_no_args(self): cmd = mock.Mock() @@ -339,7 +360,7 @@ def test_do_ssh_op_aad_user_compute(self, mock_write_cert, mock_ssh_creds, mock_ cmd.cli_ctx.cloud = mock.Mock() cmd.cli_ctx.cloud.name = "azurecloud" - op_info = ssh_info.SSHSession(None, None, "1.2.3.4", None, None, False, None, None, None, None, None, None, "Microsoft.Compute", None, None) + op_info = ssh_info.SSHSession(None, None, "1.2.3.4", None, None, False, None, None, None, None, None, None, "Microsoft.Compute", None, None, False) op_info.public_key_file = "publicfile" op_info.private_key_file = "privatefile" op_info.ssh_client_folder = "/client/folder" @@ -387,7 +408,7 @@ def test_do_ssh_op_no_public_ip(self, mock_ip, mock_check_files): mock_op = mock.Mock() mock_ip.return_value = None - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, None, None, "Microsoft.Compute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, None, None, "Microsoft.Compute", None, None, False) self.assertRaises( azclierror.ResourceNotFoundError, custom._do_ssh_op, cmd, op_info, mock_op) @@ -405,7 +426,7 @@ def test_do_ssh_op_arc_local_user(self, mock_get_cert, mock_check_keys, mock_sta cmd = mock.Mock() mock_op = mock.Mock() - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, [], False, "Microsoft.HybridCompute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, [], False, "Microsoft.HybridCompute", None, None, False) op_info.private_key_file = "priv" op_info.cert_file = "cert" op_info.ssh_client_folder = "client" @@ -450,7 +471,7 @@ def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_ mock_op = mock.Mock() - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, "port", None, [], False, "Microsoft.HybridCompute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, "port", None, [], False, "Microsoft.HybridCompute", None, None, False) op_info.public_key_file = "publicfile" op_info.private_key_file = "privatefile" op_info.ssh_client_folder = "client" @@ -470,21 +491,21 @@ def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_ def test_decide_resource_type_ip(self): cmd = mock.Mock() - op_info = ssh_info.SSHSession(None, None, "ip", None, None, False, None, None, None, None, [], False, None, None, None) + op_info = ssh_info.SSHSession(None, None, "ip", None, None, False, None, None, None, None, [], False, None, None, None, False) self.assertEqual(custom._decide_resource_type(cmd, op_info), "Microsoft.Compute") @mock.patch('azext_ssh.custom._check_if_arc_server') def test_decide_resource_type_resourcetype_arc(self, mock_is_arc): cmd = mock.Mock() mock_is_arc.return_value = None, None, True - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, "Microsoft.HybridCompute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, "Microsoft.HybridCompute", None, None, False) self.assertEqual(custom._decide_resource_type(cmd, op_info), "Microsoft.HybridCompute") @mock.patch('azext_ssh.custom._check_if_azure_vm') def test_decide_resource_type_resourcetype_arc(self, mock_is_vm): cmd = mock.Mock() mock_is_vm.return_value = None, None, True - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, "Microsoft.Compute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, "Microsoft.Compute", None, None, False) self.assertEqual(custom._decide_resource_type(cmd, op_info), "Microsoft.Compute") @mock.patch('azext_ssh.custom._check_if_azure_vm') @@ -493,7 +514,7 @@ def test_decide_resource_type_rg_vm_both(self, mock_is_arc, mock_is_vm): cmd = mock.Mock() mock_is_vm.return_value = None, None, True mock_is_arc.return_value = None, None, True - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None, False) self.assertRaises( azclierror.BadRequestError, custom._decide_resource_type, cmd, op_info) @@ -503,7 +524,7 @@ def test_decide_resource_type_rg_vm_neither(self, mock_is_arc, mock_is_vm): cmd = mock.Mock() mock_is_vm.return_value = None, ResourceNotFoundError(), False mock_is_arc.return_value = None, ResourceNotFoundError(), False - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None, False) self.assertRaises( azclierror.ResourceNotFoundError, custom._decide_resource_type, cmd, op_info) @@ -513,7 +534,7 @@ def test_decide_resource_type_rg_vm_arc(self, mock_is_arc, mock_is_vm): cmd = mock.Mock() mock_is_vm.return_value = None, ResourceNotFoundError(), False mock_is_arc.return_value = None, None, True - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None, False) self.assertEqual(custom._decide_resource_type(cmd, op_info), "Microsoft.HybridCompute") @mock.patch('azext_ssh.custom._check_if_azure_vm') @@ -522,7 +543,7 @@ def test_decide_resource_type_rg_vm_arc(self, mock_is_arc, mock_is_vm): cmd = mock.Mock() mock_is_vm.return_value = None, None, True mock_is_arc.return_value = None, ResourceNotFoundError(), False - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, None, None, None, None, [], False, None, None, None, False) self.assertEqual(custom._decide_resource_type(cmd, op_info), "Microsoft.Compute") diff --git a/src/ssh/azext_ssh/tests/latest/test_rdp_utils.py b/src/ssh/azext_ssh/tests/latest/test_rdp_utils.py new file mode 100644 index 00000000000..74489140d52 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_rdp_utils.py @@ -0,0 +1,97 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from unittest import mock + +from azext_ssh import rdp_utils +from azext_ssh import ssh_info +from azext_ssh import ssh_utils + +class RDPUtilsTest(unittest.TestCase): + @mock.patch('os.environ.copy') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') + @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') + @mock.patch("subprocess.Popen") + def test_start_ssh_tunnel(self, mock_popen, mock_relay, mock_path, mock_env): + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1', 'arg2', '-v'], False, "Microsoft.HybridCompute", None, None, True) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + op_info.proxy_path = "proxy" + op_info.relay_info = "relay" + + mock_env.return_value = {'var1':'value1', 'var2':'value2', 'var3':'value3'} + mock_path.return_value = 'ssh' + mock_relay.return_value = 'relay_string' + mock_popen.return_value = 'ssh_process' + + expected_command = ['ssh', "user@vm", '-o', 'ProxyCommand=\"proxy\" -p port', '-i', 'priv', '-o', 'CertificateFile=\"cert\"', 'arg1', 'arg2', '-v'] + expected_env = {'var1':'value1', 'var2':'value2', 'var3':'value3', 'SSHPROXY_RELAY_INFO':'relay_string'} + + ssh_sub, print_logs = rdp_utils.start_ssh_tunnel(op_info) + + self.assertEqual(ssh_sub, 'ssh_process') + self.assertEqual(print_logs, True) + mock_popen.assert_called_once_with(expected_command, shell=True, stderr=mock.ANY, env=expected_env, encoding='utf-8') + mock_relay.assert_called_once_with("relay") + mock_path.assert_called_once_with('ssh', 'client') + + @mock.patch('platform.system') + @mock.patch('platform.architecture') + @mock.patch('os.environ') + @mock.patch('os.path.join') + @mock.patch('os.path.isfile') + def test_get_rdp_path(self, mock_isfile, mock_join, mock_env, mock_arch, mock_sys): + mock_env.__getitem__.return_value = "root" + mock_join.side_effect = ['root/sys', 'root/sys/rdp'] + mock_arch.return_value = '32bit' + mock_sys.return_value = 'Windows' + mock_isfile.return_value = True + + expected_join_calls = [ + mock.call("root", 'System32'), + mock.call("root/sys", "mstsc.exe") + ] + + rdp_utils._get_rdp_path() + + mock_join.assert_has_calls(expected_join_calls) + mock_isfile.assert_called_once_with('root/sys/rdp') + + #start rdp connection + @mock.patch.object(rdp_utils, '_get_open_port') + @mock.patch.object(rdp_utils, 'is_local_port_open') + @mock.patch.object(rdp_utils, 'start_ssh_tunnel') + @mock.patch.object(rdp_utils, 'wait_for_ssh_connection') + @mock.patch.object(rdp_utils, 'call_rdp') + @mock.patch.object(rdp_utils, 'terminate_ssh') + def test_start_rdp_connection(self, mock_terminate, mock_rdp, mock_wait, mock_tunnel, mock_isopen, mock_getport): + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1', 'arg2'], False, "Microsoft.HybridCompute", None, None, True) + op_info.public_key_file = "pub" + op_info.private_key_file = "priv" + op_info.cert_file = "cert" + op_info.ssh_client_folder = "client" + op_info.proxy_path = "proxy" + op_info.relay_info = "relay" + + mock_getport.return_value = 1020 + mock_isopen.return_value = True + ssh_pro = mock.Mock() + #ssh_pro.return_value.poll.return_value = None + mock_tunnel.return_value = ssh_pro, False + mock_wait.return_value = True, [] + + rdp_utils.start_rdp_connection(op_info, True, True) + + mock_terminate.assert_called_once_with(ssh_pro, [], False) + #mock_rdp.assert_called_once_with(1020) + mock_tunnel.assert_called_once_with(op_info) + mock_wait.assert_called_once_with(ssh_pro, False) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_info.py b/src/ssh/azext_ssh/tests/latest/test_ssh_info.py index 40a231d5088..517b554f1d3 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_info.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_info.py @@ -21,7 +21,7 @@ def test_ssh_session(self, mock_abspath): mock.call("proxy/path"), mock.call("cred/path") ] - session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", ['-v', '-E', 'path'], False, 'arc', 'proxy/path', 'cred/path') + session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", ['-v', '-E', 'path'], False, 'arc', 'proxy/path', 'cred/path', True) mock_abspath.assert_has_calls(expected_abspath_calls) self.assertEqual(session.resource_group_name, "rg") self.assertEqual(session.vm_name, "vm") @@ -40,24 +40,25 @@ def test_ssh_session(self, mock_abspath): self.assertEqual(session.resource_type, "arc") self.assertEqual(session.proxy_path, None) self.assertEqual(session.delete_credentials, False) + self.assertEqual(session.winrdp, True) def test_ssh_session_get_host(self): - session = ssh_info.SSHSession(None, None, "ip", None, None, False, "user", None, None, None, [], False, "Microsoft.Compute", None, None) + session = ssh_info.SSHSession(None, None, "ip", None, None, False, "user", None, None, None, [], False, "Microsoft.Compute", None, None, False) self.assertEqual("user@ip", session.get_host()) - session = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, None, None, [], False, "Microsoft.HybridCompute", None, None) + session = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, None, None, [], False, "Microsoft.HybridCompute", None, None, True) self.assertEqual("user@vm", session.get_host()) @mock.patch('os.path.abspath') def test_ssh_session_build_args_compute(self, mock_abspath): mock_abspath.side_effect = ["pub_path", "priv_path", "cert_path", "client_path"] - session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", [], None, "Microsoft.Compute", None, None) + session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", [], None, "Microsoft.Compute", None, None, False) self.assertEqual(["-i", "priv_path", "-o", "CertificateFile=\"cert_path\"", "-p", "port"], session.build_args()) @mock.patch('os.path.abspath') def test_ssh_session_build_args_hyvridcompute(self, mock_abspath): mock_abspath.side_effect = ["pub_path", "priv_path", "cert_path", "client_path"] - session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", [], None, "Microsoft.HybridCompute", None, None) + session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", [], None, "Microsoft.HybridCompute", None, None, True) session.proxy_path = "proxy_path" self.assertEqual(["-o", "ProxyCommand=\"proxy_path\" -p port", "-i", "priv_path", "-o", "CertificateFile=\"cert_path\""], session.build_args()) diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index e0adb1fa14d..26877814d08 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -15,13 +15,13 @@ class SSHUtilsTests(unittest.TestCase): @mock.patch.object(ssh_utils, '_start_cleanup') @mock.patch.object(ssh_utils, '_terminate_cleanup') - @mock.patch.object(ssh_utils, '_get_ssh_client_path') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') @mock.patch('subprocess.run') @mock.patch('os.environ.copy') @mock.patch('platform.system') def test_start_ssh_connection_compute(self, mock_system, mock_copy_env, mock_call, mock_path, mock_terminatecleanup, mock_startcleanup): - op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", "ip", None, None, False, "user", None, "port", None, ['arg1', 'arg2', 'arg3'], False, "Microsof.Compute", None, None, False) op_info.public_key_file = "pub" op_info.private_key_file = "priv" op_info.cert_file = "cert" @@ -44,13 +44,13 @@ def test_start_ssh_connection_compute(self, mock_system, mock_copy_env, mock_cal @mock.patch.object(ssh_utils, '_terminate_cleanup') @mock.patch('os.environ.copy') - @mock.patch.object(ssh_utils, '_get_ssh_client_path') + @mock.patch.object(ssh_utils, 'get_ssh_client_path') @mock.patch('subprocess.run') @mock.patch('azext_ssh.custom.connectivity_utils.format_relay_info_string') @mock.patch('platform.system') def test_start_ssh_connection_arc(self, mock_system, mock_relay_str, mock_call, mock_path, mock_copy_env, mock_terminatecleanup): - op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None) + op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, ['arg1'], False, "Microsoft.HybridCompute", None, None, False) op_info.public_key_file = "pub" op_info.private_key_file = "priv" op_info.cert_file = "cert" @@ -147,7 +147,7 @@ def test_get_ssh_client_path_with_client_folder_non_windows(self, mock_isfile, m mock_join.return_value = "ssh_path" mock_system.return_value = "Linux" mock_isfile.return_value = True - actual_path = ssh_utils._get_ssh_client_path(ssh_client_folder='/client/folder') + actual_path = ssh_utils.get_ssh_client_path(ssh_client_folder='/client/folder') self.assertEqual(actual_path, "ssh_path") mock_join.assert_called_once_with('/client/folder', 'ssh') mock_isfile.assert_called_once_with("ssh_path") @@ -159,7 +159,7 @@ def test_get_ssh_client_path_with_client_folder_windows(self, mock_isfile, mock_ mock_join.return_value = "ssh_keygen_path" mock_system.return_value = "Windows" mock_isfile.return_value = True - actual_path = ssh_utils._get_ssh_client_path(ssh_command='ssh-keygen', ssh_client_folder='/client/folder') + actual_path = ssh_utils.get_ssh_client_path(ssh_command='ssh-keygen', ssh_client_folder='/client/folder') self.assertEqual(actual_path, "ssh_keygen_path.exe") mock_join.assert_called_once_with('/client/folder', 'ssh-keygen') mock_isfile.assert_called_once_with("ssh_keygen_path.exe") @@ -171,7 +171,7 @@ def test_get_ssh_client_path_with_client_folder_no_file(self, mock_isfile, mock_ mock_join.return_value = "ssh_path" mock_system.return_value = "Mac" mock_isfile.return_value = False - actual_path = ssh_utils._get_ssh_client_path(ssh_client_folder='/client/folder') + actual_path = ssh_utils.get_ssh_client_path(ssh_client_folder='/client/folder') self.assertEqual(actual_path, "ssh") mock_join.assert_called_once_with('/client/folder', 'ssh') mock_isfile.assert_called_once_with("ssh_path") @@ -179,7 +179,7 @@ def test_get_ssh_client_path_with_client_folder_no_file(self, mock_isfile, mock_ @mock.patch('platform.system') def test_get_ssh_client_preinstalled_non_windows(self, mock_system): mock_system.return_value = "Mac" - actual_path = ssh_utils._get_ssh_client_path() + actual_path = ssh_utils.get_ssh_client_path() self.assertEqual('ssh', actual_path) mock_system.assert_called_once_with() @@ -211,7 +211,7 @@ def _test_get_ssh_client_path_preinstalled_windows(self, platform_arch, os_arch, mock.call("system32path", "openSSH", "ssh.exe") ] - actual_path = ssh_utils._get_ssh_client_path() + actual_path = ssh_utils.get_ssh_client_path() self.assertEqual("sshfilepath", actual_path) mock_system.assert_called_once_with() @@ -233,4 +233,4 @@ def test_get_ssh_path_windows_ssh_preinstalled_not_found(self, mock_isfile, mock mock_environ.__getitem__.return_value = "rootpath" mock_isfile.return_value = False - self.assertRaises(azclierror.UnclassifiedUserFault, ssh_utils._get_ssh_client_path) + self.assertRaises(azclierror.UnclassifiedUserFault, ssh_utils.get_ssh_client_path)