diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 1fd295d2c3a..b3d51a6b89b 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -65,7 +65,7 @@ def ssh_cert(cmd, cert_path=None, public_key_file=None): keys_folder = os.path.dirname(cert_path) logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate " "is no longer being used.", keys_folder) - public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) + public_key_file, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) print(cert_file + "\n") @@ -87,9 +87,11 @@ def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_ke # If user provides a local user, use the provided credentials for authentication if not username: delete_cert = True - public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, - private_key_file, - credentials_folder) + if not public_key_file and not private_key_file: + delete_keys = True + public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, + private_key_file, + credentials_folder) cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert) @@ -172,13 +174,9 @@ def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username): raise azclierror.FileOperationError(f"Certificate file {cert_file} not found") -def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): - delete_keys = False +def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder=None): # If nothing is passed, then create a directory with a ephemeral keypair if not public_key_file and not private_key_file: - # We only want to delete the keys if the user hasn't provided their own keys - # Only ssh vm deletes generated keys. - delete_keys = True if not credentials_folder: # az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails. credentials_folder = tempfile.mkdtemp(prefix="aadsshcert") @@ -206,7 +204,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre if not os.path.isfile(private_key_file): raise azclierror.FileOperationError(f"Private key file {private_key_file} not found") - return public_key_file, private_key_file, delete_keys + return public_key_file, private_key_file def _write_cert_file(certificate_contents, cert_file):