Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ssh/azext_ssh/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
CLIENT_PROXY_STORAGE_URL = "https://sshproxysa.blob.core.windows.net"
CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120
CLEANUP_TIME_INTERVAL_IN_SECONDS = 10
DEFAULT_KEY_TEMPDIR_NAME = "azclisshkeys"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's call it as DEFAULT_TEMPDIR.
We will store keys, relay information, etc in this folder.

21 changes: 15 additions & 6 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, resource_id=None, ssh_ip

_assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user)
do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, None, None,
ssh_client_path, ssh_args, delete_privkey)
ssh_client_path, ssh_args, delete_privkey, local_user)
do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user,
cert_file, port, use_private_ip)

Expand All @@ -40,7 +40,7 @@ def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=

_assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user)
do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite,
None, None, None)
None, None, None, None)
do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user,
cert_file, port, use_private_ip)

Expand Down Expand Up @@ -184,10 +184,16 @@ def _assert_args(resource_group, vm_name, ssh_ip, resource_id, cert_file, userna
def _check_or_create_public_private_files(public_key_file, private_key_file):
# If nothing is passed in create a temporary directory with a ephemeral keypair
if not public_key_file and not private_key_file:
temp_dir = tempfile.mkdtemp(prefix="aadsshcert")
temp_dir = os.path.join(tempfile.gettempdir(), consts.DEFAULT_KEY_TEMPDIR_NAME)
public_key_file = os.path.join(temp_dir, "id_rsa.pub")
private_key_file = os.path.join(temp_dir, "id_rsa")
ssh_utils.create_ssh_keyfile(private_key_file)
if not os.path.isdir(temp_dir):
new_temp_dir = tempfile.mkdtemp()
os.rename(new_temp_dir, os.path.join(os.path.dirname(new_temp_dir), consts.DEFAULT_KEY_TEMPDIR_NAME))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't we create the temp directory with the desired name. What's the need for rename?

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree that's a little weird. But the way this tempfile library works is that we can provide a prefix and/or a suffix for the name, but they are gonna add some random characters in the middle. So I just rename it instead so we don't have to deal with that.

if not os.path.isfile(public_key_file) or not os.path.isfile(private_key_file):
file_utils.delete_file(public_key_file, f"Couldn't delete existing public key {public_key_file}. ")
file_utils.delete_file(private_key_file, f"Couldn't delete existing private key {private_key_file}. ")
ssh_utils.create_ssh_keyfile(private_key_file)

if not public_key_file:
if private_key_file:
Expand Down Expand Up @@ -311,7 +317,7 @@ def _arc_list_access_details(cmd, resource_group, vm_name):


def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite,
ssh_client_path, ssh_args, delete_privkey):
ssh_client_path, ssh_args, delete_privkey, local_user):

# If the user provides an IP address the target will be treated as an Azure VM even if it is an
# Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection.
Expand Down Expand Up @@ -350,8 +356,11 @@ def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, conf
op_call = functools.partial(ssh_utils.write_ssh_config, config_path=config_path, overwrite=overwrite,
resource_group=resource_group_name)
else:
delete_cert = False
if not local_user:
delete_cert = True
op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args,
delete_privkey=delete_privkey)
delete_privkey=delete_privkey, delete_cert=delete_cert)
do_ssh_op = functools.partial(_do_ssh_op, resource_group_name=resource_group_name, vm_name=vm_name,
is_arc=is_arc_server, op_call=op_call)

Expand Down
15 changes: 8 additions & 7 deletions src/ssh/azext_ssh/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ def mkdir_p(path):


def delete_file(file_path, message, warning=False):
try:
os.remove(file_path)
except Exception as e:
if warning:
logger.warning(message)
else:
raise azclierror.FileOperationError(message + "Error: " + str(e)) from e
if os.path.isfile(file_path):
try:
os.remove(file_path)
except Exception as e:
if warning:
logger.warning(message)
else:
raise azclierror.FileOperationError(message + "Error: " + str(e)) from e


def create_directory(file_path, error_message):
Expand Down
50 changes: 32 additions & 18 deletions src/ssh/azext_ssh/ssh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_file, private_key_file, port,
is_arc, ssh_client_path, ssh_args, delete_privkey):
is_arc, ssh_client_path, ssh_args, delete_privkey, delete_cert):

if not ssh_client_path:
ssh_client_path = _get_ssh_path()
Expand All @@ -33,11 +33,17 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil

ssh_client_log_file_arg = []
# delete_privkey is only true for injected commands in the portal one click ssh experience
if delete_privkey and (cert_file or private_key_file):
# if delete_privkey is true, delete certificate and private key
# if delete_cert is true, delete only certificate
# I admit that these names are confusing. We Should probably rename them.
if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file):

#Is there a point to do this if the os is Windows? It seems that it doesn't work,
#It's just extra work
if '-E' in ssh_arg_list:
# This condition should rarely be true
index = ssh_arg_list.index('-E')
log_file = ssh_arg_list[index + 1]
# if the user provides their own log file, we should probably not overwrite it
else:
if cert_file:
log_dir = os.path.dirname(cert_file)
Expand All @@ -47,6 +53,7 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil
log_file = os.path.join(log_dir, log_file_name)
ssh_client_log_file_arg = ['-E', log_file]

# This might be a problem, because the user won't get the verbosity printed.
if '-v' not in ssh_arg_list and '-vv' not in ssh_arg_list and '-vvv' not in ssh_arg_list:
ssh_client_log_file_arg = ssh_client_log_file_arg + ['-v']

Expand All @@ -68,22 +75,23 @@ def start_ssh_connection(relay_info, proxy_path, vm_name, ip, username, cert_fil
# If delete_privkey flag is true, we will try to clean the private key file and the certificate file
# once the connection has been established. If it's not possible to open the log file, we default to
# waiting for about 2 minutes once the ssh process starts before cleaning up the files.
if delete_privkey and (cert_file or private_key_file):
if os.path.isfile(log_file):
file_utils.delete_file(log_file, f"Couldn't delete existing log file {log_file}", True)
cleanup_process = mp.Process(target=_do_cleanup, args=(private_key_file, cert_file, log_file))
if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file):
#We shouldn't delete the file if it's provided by the user
file_utils.delete_file(log_file, f"Couldn't delete existing log file {log_file}", True)
cleanup_process = mp.Process(target=_do_cleanup, args=(private_key_file, cert_file, delete_privkey, delete_cert, log_file))
cleanup_process.start()

logger.debug("Running ssh command %s", ' '.join(command))
subprocess.call(command, shell=platform.system() == 'Windows', env=env)

# If the cleanup process is still alive once the ssh process is terminated, we terminate it and make
# sure the private key and certificate are deleted.
if delete_privkey and (cert_file or private_key_file):
# Should we also delete the log file if it's not provided by the user? Or will this not allow the user to ever look at their client side logs?
if (delete_privkey and (cert_file or private_key_file)) or (delete_cert and cert_file):
if cleanup_process.is_alive():
cleanup_process.terminate()
time.sleep(1)
_do_cleanup(private_key_file, cert_file)
_do_cleanup(private_key_file, cert_file, delete_privkey, delete_cert)


def create_ssh_keyfile(private_key_file):
Expand Down Expand Up @@ -118,6 +126,7 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username,
cert_file, private_key_file, port, is_arc, config_path, overwrite, resource_group):

common_lines = []
common_lines.append("Host " + resource_group + "-" + vm_name)
common_lines.append("\tUser " + username)
if cert_file:
common_lines.append("\tCertificateFile " + cert_file)
Expand All @@ -131,23 +140,26 @@ def write_ssh_config(relay_info, proxy_path, vm_name, ip, username,
elif private_key_file:
relay_info_dir = os.path.dirname(private_key_file)
else:
relay_info_dir = tempfile.mkdtemp(prefix="ssharcrelayinfo")
relay_info_path = os.path.join(relay_info_dir, "relay_info")
relay_info_dir = os.path.join(tempfile.gettempdir(), const.DEFAULT_KEY_TEMPDIR_NAME)
if not os.path.isdir(relay_info_dir):
new_dir = tempfile.mkdtemp()
os.rename(new_dir, os.path.join(os.path.dirname(new_dir), const.DEFAULT_KEY_TEMPDIR_NAME))

relay_info_filename = "relay_info_" + vm_name + "_" + resource_group
relay_info_path = os.path.join(relay_info_dir, relay_info_filename)
file_utils.write_to_file(relay_info_path, 'w', relay_info,
f"Couldn't write relay information to file {relay_info_path}", 'utf-8')
oschmod.set_mode(relay_info_path, stat.S_IRUSR)

lines.append("Host " + vm_name)
lines = lines + common_lines
if port:
lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path + " " + "-p " + port)
else:
lines.append("\tProxyCommand " + proxy_path + " " + "-r " + relay_info_path)
else:
if resource_group and vm_name:
lines.append("Host " + resource_group + "-" + vm_name)
lines.append("\tHostName " + ip)
lines = lines + common_lines
lines.append("\tHostName " + ip)
if port:
lines.append("\tPort " + port)

Expand Down Expand Up @@ -208,8 +220,10 @@ def _build_args(cert_file, private_key_file, port):
return private_key + certificate + port_arg


def _do_cleanup(private_key_file, cert_file, log_file=None):
if os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0":
def _do_cleanup(private_key_file, cert_file, delete_privkey, delete_cert, log_file=None):
print(delete_privkey)
print(delete_cert)
if delete_privkey and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0":
raise azclierror.BadRequestError("Can't delete private key file. "
"The --delete-private-key flag set to True, "
"but this is not an Azure Cloud Shell session.")
Expand All @@ -229,8 +243,8 @@ def _do_cleanup(private_key_file, cert_file, log_file=None):
if t1 < const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS:
time.sleep(const.CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS - t1)

if private_key_file and os.path.isfile(private_key_file):
if delete_privkey and private_key_file:
file_utils.delete_file(private_key_file, f"Failed to delete private key file '{private_key_file}'. ")

if cert_file and os.path.isfile(cert_file):
if (delete_privkey or delete_cert) and cert_file:
file_utils.delete_file(cert_file, f"Failed to delete certificate file '{cert_file}'. ")
4 changes: 2 additions & 2 deletions src/ssh/azext_ssh/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_assert_args_cert_with_no_user(self):
def test_assert_args_invalid_cert_filepath(self, mock_is_file):
mock_is_file.return_value = False
self.assertRaises(azclierror.FileOperationError, custom._assert_args, 'rg', 'vm', None, None, 'cert_path', 'username')

'''
@mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile')
@mock.patch('tempfile.mkdtemp')
@mock.patch('os.path.isfile')
Expand All @@ -255,7 +255,7 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf
mock_create.assert_has_calls([
mock.call('/tmp/aadtemp/id_rsa')
])

'''
@mock.patch('os.path.isfile')
@mock.patch('os.path.join')
def test_check_or_create_public_private_files_no_public(self, mock_join, mock_isfile):
Expand Down