diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 1fd295d2c3a..437d8f80e43 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -63,11 +63,16 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): keys_folder = None if not public_key_file: keys_folder = os.path.dirname(cert_path) - logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate " - "is no longer being used.", keys_folder) + print(f"The generated SSH keys are saved in {keys_folder}. Please delete SSH keys when the certificate " + "is no longer being used.") public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) - print(cert_file + "\n") + try: + cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file)[1] + print(f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time.") + except Exception as e: + logger.warning("Couldn't determine certificate validity. Error: %s", str(e)) + print(cert_file + "\n") def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index bf27741aa7a..48cc37a2840 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -7,6 +7,7 @@ import subprocess import time import multiprocessing as mp +import datetime from azext_ssh import file_utils from knack import log @@ -91,14 +92,26 @@ def get_ssh_cert_principals(cert_file): return principals -def get_ssh_cert_validity(cert_file): - info = get_ssh_cert_info(cert_file) - for line in info: - if "Valid:" in line: - return line.strip() +def _get_ssh_cert_validity(cert_file): + if cert_file: + info = get_ssh_cert_info(cert_file) + for line in info: + if "Valid:" in line: + return line.strip() return None +def get_certificate_start_and_end_times(cert_file): + validity_str = _get_ssh_cert_validity(cert_file) + times = None + if validity_str and "Valid: from " in validity_str and " to " in validity_str: + times = validity_str.replace("Valid: from ", "").split(" to ") + t0 = datetime.datetime.strptime(times[0], '%Y-%m-%dT%X') + t1 = datetime.datetime.strptime(times[1], '%Y-%m-%dT%X') + times = (t0, t1) + return times + + def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, ip, username, cert_file, private_key_file, delete_keys, delete_cert): @@ -110,13 +123,18 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, if not delete_keys: path_to_delete = cert_file items_to_delete = "" - validity = get_ssh_cert_validity(cert_file) - validity_warning = "" - if validity: - validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s%s\n" - "Please delete it once you no longer need this config file. ", - path_to_delete, items_to_delete, validity_warning) + + expiration = None + try: + expiration = get_certificate_start_and_end_times(cert_file)[1] + expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") + except Exception as e: + logger.warning("Couldn't determine certificate expiration. Error: %s", str(e)) + + if expiration: + print(f"The generated certificate {cert_file} is valid until {expiration} in local time.") + print(f"{path_to_delete} contains sensitive information{items_to_delete}. " + "Please delete it once you no longer this config file.") lines = [""] @@ -125,9 +143,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, lines.append("\tUser " + username) lines.append("\tHostName " + ip) if cert_file: - lines.append("\tCertificateFile " + cert_file) + lines.append("\tCertificateFile \"" + cert_file + "\"") if private_key_file: - lines.append("\tIdentityFile " + private_key_file) + lines.append("\tIdentityFile \"" + private_key_file + "\"") if port: lines.append("\tPort " + port) @@ -138,9 +156,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, lines.append("Host " + ip) lines.append("\tUser " + username) if cert_file: - lines.append("\tCertificateFile " + cert_file) + lines.append("\tCertificateFile \"" + cert_file + "\"") if private_key_file: - lines.append("\tIdentityFile " + private_key_file) + lines.append("\tIdentityFile \"" + private_key_file + "\"") if port: lines.append("\tPort " + port) @@ -188,7 +206,7 @@ def _build_args(cert_file, private_key_file, port): if port: port_arg = ["-p", port] if cert_file: - certificate = ["-o", "CertificateFile=" + cert_file] + certificate = ["-o", "CertificateFile=\"" + cert_file + "\""] return private_key + certificate + port_arg diff --git a/src/ssh/setup.py b/src/ssh/setup.py index cac0b4129a5..0701d0894c5 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "1.0.0" +VERSION = "1.0.1" CLASSIFIERS = [ 'Development Status :: 4 - Beta',