Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,16 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None):
keys_folder = None
if not public_key_file:
keys_folder = os.path.dirname(cert_path)
logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate "
"is no longer being used.", keys_folder)
print(f"The generated SSH keys are saved in {keys_folder}. Please delete SSH keys when the certificate "
"is no longer being used.")
public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder)
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path)
print(cert_file + "\n")
try:
cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file)[1]
print(f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time.")
except Exception as e:
logger.warning("Couldn't determine certificate validity. Error: %s", str(e))
print(cert_file + "\n")


def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip,
Expand Down
52 changes: 35 additions & 17 deletions src/ssh/azext_ssh/ssh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import subprocess
import time
import multiprocessing as mp
import datetime
from azext_ssh import file_utils

from knack import log
Expand Down Expand Up @@ -91,14 +92,26 @@ def get_ssh_cert_principals(cert_file):
return principals


def get_ssh_cert_validity(cert_file):
info = get_ssh_cert_info(cert_file)
for line in info:
if "Valid:" in line:
return line.strip()
def _get_ssh_cert_validity(cert_file):
if cert_file:
info = get_ssh_cert_info(cert_file)
for line in info:
if "Valid:" in line:
return line.strip()
return None


def get_certificate_start_and_end_times(cert_file):
validity_str = _get_ssh_cert_validity(cert_file)
times = None
if validity_str and "Valid: from " in validity_str and " to " in validity_str:
times = validity_str.replace("Valid: from ", "").split(" to ")
t0 = datetime.datetime.strptime(times[0], '%Y-%m-%dT%X')
t1 = datetime.datetime.strptime(times[1], '%Y-%m-%dT%X')
times = (t0, t1)
return times


def write_ssh_config(config_path, resource_group, vm_name, overwrite, port,
ip, username, cert_file, private_key_file, delete_keys, delete_cert):

Expand All @@ -110,13 +123,18 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port,
if not delete_keys:
path_to_delete = cert_file
items_to_delete = ""
validity = get_ssh_cert_validity(cert_file)
validity_warning = ""
if validity:
validity_warning = f" {validity.lower()}"
logger.warning("%s contains sensitive information%s%s\n"
"Please delete it once you no longer need this config file. ",
path_to_delete, items_to_delete, validity_warning)

expiration = None
try:
expiration = get_certificate_start_and_end_times(cert_file)[1]
expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p")
except Exception as e:
logger.warning("Couldn't determine certificate expiration. Error: %s", str(e))

if expiration:
print(f"The generated certificate {cert_file} is valid until {expiration} in local time.")
print(f"{path_to_delete} contains sensitive information{items_to_delete}. "
"Please delete it once you no longer this config file.")

lines = [""]

Expand All @@ -125,9 +143,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port,
lines.append("\tUser " + username)
lines.append("\tHostName " + ip)
if cert_file:
lines.append("\tCertificateFile " + cert_file)
lines.append("\tCertificateFile \"" + cert_file + "\"")
if private_key_file:
lines.append("\tIdentityFile " + private_key_file)
lines.append("\tIdentityFile \"" + private_key_file + "\"")
if port:
lines.append("\tPort " + port)

Expand All @@ -138,9 +156,9 @@ def write_ssh_config(config_path, resource_group, vm_name, overwrite, port,
lines.append("Host " + ip)
lines.append("\tUser " + username)
if cert_file:
lines.append("\tCertificateFile " + cert_file)
lines.append("\tCertificateFile \"" + cert_file + "\"")
if private_key_file:
lines.append("\tIdentityFile " + private_key_file)
lines.append("\tIdentityFile \"" + private_key_file + "\"")
if port:
lines.append("\tPort " + port)

Expand Down Expand Up @@ -188,7 +206,7 @@ def _build_args(cert_file, private_key_file, port):
if port:
port_arg = ["-p", port]
if cert_file:
certificate = ["-o", "CertificateFile=" + cert_file]
certificate = ["-o", "CertificateFile=\"" + cert_file + "\""]
return private_key + certificate + port_arg


Expand Down
2 changes: 1 addition & 1 deletion src/ssh/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from setuptools import setup, find_packages

VERSION = "1.0.0"
VERSION = "1.0.1"

CLASSIFIERS = [
'Development Status :: 4 - Beta',
Expand Down