Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/ssh/azext_ssh/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def load_arguments(self, _):
c.argument('port', options_list=['--port'], help='SSH port')
c.argument('ssh_client_path', options_list=['--ssh-client-path'],
help='Path to ssh executable. Default to ssh pre-installed if not provided.')
c.argument('delete_privkey', options_list=['--delete-private-key'],
c.argument('delete_credentials', options_list=['--delete-private-key'],
help=('This is an internal argument. This argument is used by Azure Portal to provide a one click '
'SSH login experience in Cloud shell.'),
deprecate_info=c.deprecate(hide=True), action='store_true')
Expand All @@ -41,13 +41,17 @@ def load_arguments(self, _):
help='The username for a local user')
c.argument('overwrite', action='store_true', options_list=['--overwrite'],
help='Overwrites the config file if this flag is set')
c.argument('credentials_folder', options_list=['--keys-destination-folder', '--keys-dest-folder'],
help='Folder where new generated keys will be stored.')
c.argument('port', options_list=['--port'], help='Port to connect to on the remote host.')
c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to certificate file')

with self.argument_context('ssh cert') as c:
c.argument('cert_path', options_list=['--file', '-f'],
help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened')
c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path')
c.argument('public_key_file', options_list=['--public-key-file', '-p'],
help='The RSA public key file path. If not provided, '
'generated key pair is stored in the same directory as --file.')

with self.argument_context('ssh arc') as c:
c.argument('vm_name', options_list=['--vm-name', '--name', '-n'], help='The name of the Arc Server')
Expand All @@ -60,7 +64,7 @@ def load_arguments(self, _):
c.argument('port', options_list=['--port'], help='Port to connect to on the remote host.')
c.argument('ssh_client_path', options_list=['--ssh-client-path'],
help='Path to ssh executable. Default to ssh pre-installed if not provided.')
c.argument('delete_privkey', options_list=['--delete-private-key'],
c.argument('delete_credentials', options_list=['--delete-private-key'],
help=('This is an internal argument. This argument is used by Azure Portal to provide a one click '
'SSH login experience in Cloud shell.'),
deprecate_info=c.deprecate(hide=True), action='store_true')
Expand Down
1 change: 1 addition & 0 deletions src/ssh/azext_ssh/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
CLIENT_PROXY_STORAGE_URL = "https://sshproxysa.blob.core.windows.net"
CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120
CLEANUP_TIME_INTERVAL_IN_SECONDS = 10
CLEANUP_AWAIT_TERMINATION_IN_SECONDS = 30
105 changes: 81 additions & 24 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import stat
from glob import glob

from knack import log
from azure.cli.core import azclierror
from msrestazure import tools

Expand All @@ -22,38 +23,75 @@
from . import constants as consts
from . import file_utils

logger = log.get_logger(__name__)

def ssh_vm(cmd, resource_group_name=None, vm_name=None, resource_id=None, ssh_ip=None, public_key_file=None,
private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None,
ssh_client_path=None, delete_privkey=False, ssh_args=None):
ssh_client_path=None, delete_credentials=False, ssh_args=None):

if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0":
raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.")

_assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user)
credentials_folder = None
do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, None, None,
ssh_client_path, ssh_args, delete_privkey)
ssh_client_path, ssh_args, delete_credentials, credentials_folder)
do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user,
cert_file, port, use_private_ip)
cert_file, port, use_private_ip, credentials_folder)


def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, resource_id=None,
public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False,
local_user=None, cert_file=None, port=None):
local_user=None, cert_file=None, port=None, credentials_folder=None):

_assert_args(resource_group_name, vm_name, ssh_ip, resource_id, cert_file, local_user)

if (public_key_file or private_key_file) and credentials_folder:
raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with "
"--public-key-file/-p or --private-key-file/-i.")

# Default credential location
if not credentials_folder:
config_folder = os.path.dirname(config_path)
if not os.path.isdir(config_folder):
raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} "
"does not exist.")
folder_name = ssh_ip
if resource_group_name and vm_name:
folder_name = resource_group_name + "-" + vm_name
elif resource_id:
resource_info = tools.parse_resource_id(resource_id)
folder_name = resource_info['resource_group'] + "-" + resource_info['resource_name']

credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name))

do_ssh_op = _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite,
None, None, None)
None, None, False, credentials_folder)
do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, local_user,
cert_file, port, use_private_ip)
cert_file, port, use_private_ip, credentials_folder)


def ssh_cert(cmd, cert_path=None, public_key_file=None):
public_key_file, _ = _check_or_create_public_private_files(public_key_file, None)
if not cert_path and not public_key_file:
raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.")
if cert_path and not os.path.isdir(os.path.dirname(cert_path)):
raise azclierror.InvalidArgumentValueError(f"{os.path.dirname(cert_path)} folder doesn't exist")
# If user doesn't provide a public key, save generated key pair to the same folder as --file
keys_folder = None
if not public_key_file:
keys_folder = os.path.dirname(cert_path)
logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate "
"is no longer being used.", keys_folder)
public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder)
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path)
print(cert_file + "\n")


def ssh_arc(cmd, resource_group_name=None, vm_name=None, resource_id=None, public_key_file=None, private_key_file=None,
local_user=None, cert_file=None, port=None, ssh_client_path=None, delete_privkey=False, ssh_args=None):

local_user=None, cert_file=None, port=None, ssh_client_path=None, delete_credentials=False, ssh_args=None):

if delete_credentials and os.environ.get("AZUREPS_HOST_ENVIRONMENT") != "cloud-shell/1.0":
raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.")
_assert_args(resource_group_name, vm_name, None, resource_id, cert_file, local_user)

if resource_id:
Expand All @@ -66,14 +104,16 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, resource_id=None, publi
resource_group_name = resource_info['resource_group']
vm_name = resource_info['resource_name']

credentials_folder = None

op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args,
delete_privkey=delete_privkey)
delete_credentials=delete_credentials)
_do_ssh_op(cmd, None, public_key_file, private_key_file, local_user, cert_file, port,
False, resource_group_name, vm_name, op_call, True)
False, credentials_folder, resource_group_name, vm_name, op_call, True)


def _do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, username,
cert_file, port, use_private_ip, resource_group_name, vm_name, op_call, is_arc):
cert_file, port, use_private_ip, credentials_folder, resource_group_name, vm_name, op_call, is_arc):

proxy_path = None
relay_info = None
Expand All @@ -87,12 +127,18 @@ def _do_ssh_op(cmd, ssh_ip, public_key_file, private_key_file, username,
raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public IP address to SSH to")
raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to"
"SSH to")


# If user provides local user, no credentials should be deleted.
delete_keys = False
delete_cert = False
if not username:
public_key_file, private_key_file = _check_or_create_public_private_files(public_key_file, private_key_file)
delete_cert = True
public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file,
private_key_file,
credentials_folder)
cert_file, username = _get_and_write_certificate(cmd, public_key_file, None)

op_call(relay_info, proxy_path, vm_name, ssh_ip, username, cert_file, private_key_file, port, is_arc)
op_call(relay_info, proxy_path, vm_name, ssh_ip, username, cert_file, private_key_file, port, is_arc, delete_keys, delete_cert, public_key_file)


def _get_and_write_certificate(cmd, public_key_file, cert_file):
Expand Down Expand Up @@ -181,12 +227,23 @@ def _assert_args(resource_group, vm_name, ssh_ip, resource_id, cert_file, userna
raise azclierror.FileOperationError(f"Certificate file {cert_file} not found")


def _check_or_create_public_private_files(public_key_file, private_key_file):
def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder):
delete_keys = False
# If nothing is passed in create a temporary directory with a ephemeral keypair
if not public_key_file and not private_key_file:
temp_dir = tempfile.mkdtemp(prefix="aadsshcert")
public_key_file = os.path.join(temp_dir, "id_rsa.pub")
private_key_file = os.path.join(temp_dir, "id_rsa")
# We only want to delete the keys if the user hasn't provided their own keys
# Only ssh vm deletes generated keys.
delete_keys = True
if not credentials_folder:
# az ssh vm: Create keys on temp folder and delete folder once connection succeeds/fails.
credentials_folder = tempfile.mkdtemp(prefix="aadsshcert")
else:
# az ssh config: Keys saved to the same folder as --file or to --keys-destination-folder.
# az ssh cert: Keys saved to the same folder as --file.
if not os.path.isdir(credentials_folder):
os.makedirs(credentials_folder)
public_key_file = os.path.join(credentials_folder, "id_rsa.pub")
private_key_file = os.path.join(credentials_folder, "id_rsa")
ssh_utils.create_ssh_keyfile(private_key_file)

if not public_key_file:
Expand All @@ -204,7 +261,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file):
if not os.path.isfile(private_key_file):
raise azclierror.FileOperationError(f"Private key file {private_key_file} not found")

return public_key_file, private_key_file
return public_key_file, private_key_file, delete_keys


def _write_cert_file(certificate_contents, cert_file):
Expand Down Expand Up @@ -319,7 +376,7 @@ def _arc_list_access_details(cmd, resource_group, vm_name):


def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, config_path, overwrite,
ssh_client_path, ssh_args, delete_privkey):
ssh_client_path, ssh_args, delete_credentials, credentials_folder):

# If the user provides an IP address the target will be treated as an Azure VM even if it is an
# Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection.
Expand Down Expand Up @@ -353,17 +410,17 @@ def _decide_op_call(cmd, resource_group_name, vm_name, resource_id, ssh_ip, conf
from azure.core.exceptions import ResourceNotFoundError
if isinstance(arc_error, ResourceNotFoundError) and isinstance(vm_error, ResourceNotFoundError):
raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group "
"{resource_group_name} was not found. Erros:\n"
f"{resource_group_name} was not found. Erros:\n"
f"{str(arc_error)}\n{str(vm_error)}")
raise azclierror.BadRequestError("Unable to determine the target machine type as Azure VM or "
f"Arc Server. Errors:\n{str(arc_error)}\n{str(vm_error)}")

if config_path:
op_call = functools.partial(ssh_utils.write_ssh_config, config_path=config_path, overwrite=overwrite,
resource_group=resource_group_name)
resource_group=resource_group_name, credentials_folder=credentials_folder)
else:
op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args,
delete_privkey=delete_privkey)
delete_credentials=delete_credentials)
do_ssh_op = functools.partial(_do_ssh_op, resource_group_name=resource_group_name, vm_name=vm_name,
is_arc=is_arc_server, op_call=op_call)

Expand Down
27 changes: 20 additions & 7 deletions src/ssh/azext_ssh/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,26 @@ def mkdir_p(path):


def delete_file(file_path, message, warning=False):
try:
os.remove(file_path)
except Exception as e:
if warning:
logger.warning(message)
else:
raise azclierror.FileOperationError(message + "Error: " + str(e)) from e
if os.path.isfile(file_path):
try:
os.remove(file_path)
except Exception as e:
if warning:
logger.warning(message)
else:
raise azclierror.FileOperationError(message + "Error: " + str(e)) from e



def delete_folder(dir_path, message, warning=False):
if os.path.isdir(dir_path):
try:
os.rmdir(dir_path)
except Exception as e:
if warning:
logger.warning(message)
else:
raise azclierror.FileOperationError(message + "Error: " + str(e)) from e


def create_directory(file_path, error_message):
Expand Down
Loading