Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,8 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, p
raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.")

_assert_args(resource_group_name, vm_name, None, "Microsoft.HybridCompute", cert_file, local_user)

credentials_folder = None

_validate_arc_server(cmd, resource_group_name, vm_name)
op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args,
delete_credentials=delete_credentials)
_do_ssh_op(cmd, vm_name, resource_group_name, None, public_key_file, private_key_file, local_user, cert_file, port,
Expand All @@ -107,8 +106,8 @@ def _do_ssh_op(cmd, vm_name, resource_group_name, ssh_ip, public_key_file, priva
proxy_path = None
relay_info = None
if is_arc:
proxy_path = _arc_get_client_side_proxy()
relay_info = _arc_list_access_details(cmd, resource_group_name, vm_name)
proxy_path = _arc_get_client_side_proxy()
else:
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group_name, vm_name, use_private_ip)
if not ssh_ip:
Expand Down Expand Up @@ -361,18 +360,30 @@ def _arc_get_client_side_proxy():
# Get the Access Details to connect to Arc Connectivity platform from the HybridConnectivity RP
def _arc_list_access_details(cmd, resource_group, vm_name):
from azext_ssh._client_factory import cf_endpoint
from azure.core.exceptions import ResourceNotFoundError
client = cf_endpoint(cmd.cli_ctx)
try:
t0 = time.time()
result = client.list_credentials(resource_group_name=resource_group, machine_name=vm_name,
endpoint_name="default")
time_elapsed = time.time() - t0
telemetry.add_extension_event('ssh', {'Context.Default.AzureCLI.SSHListCredentialsTime': time_elapsed})
except ResourceNotFoundError as e:
telemetry.set_exception(exception='Call to listCredentials failed',
fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE,
summary=f'listCredentials failed with error: {str(e)}.')
# Probably indicates that the endpoint doesn't exist. Recommend they restart the agent.
raise azclierror.ResourceNotFoundError("Request for Azure Relay Information failed with a "
"ResourceNotFoundError.\n"
"Make sure the HybridConnectivity RP is registered and the "
"HybridConnectivity feature flag is enabled. "
"If error persists, please restart hybrid agent on target machine."
f"\nError:\n{str(e)}")
except Exception as e:
telemetry.set_exception(exception='Call to listCredentials failed',
fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE,
summary=f'listCredentials failed with error: {str(e)}.')
raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed: {str(e)}")
raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed. Error:\n{str(e)}")

result_string = json.dumps(
{
Expand All @@ -392,7 +403,6 @@ def _arc_list_access_details(cmd, resource_group, vm_name):

def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, config_path, overwrite,
ssh_client_path, ssh_args, delete_credentials, credentials_folder):

# If the user provides an IP address the target will be treated as an Azure VM even if it is an
# Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection.
is_arc_server = False
Expand All @@ -402,7 +412,7 @@ def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, co

elif resource_type:
if resource_type == "Microsoft.HybridCompute":
is_arc_server = True
is_arc_server = _validate_arc_server(cmd, resource_group_name, vm_name)

else:
vm_error, is_azure_vm = _check_if_azure_vm(cmd, resource_group_name, vm_name)
Expand All @@ -415,7 +425,7 @@ def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, co
from azure.core.exceptions import ResourceNotFoundError
if isinstance(arc_error, ResourceNotFoundError) and isinstance(vm_error, ResourceNotFoundError):
raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group "
f"{resource_group_name} was not found. Erros:\n"
f"{resource_group_name} was not found. Errors:\n"
f"{str(arc_error)}\n{str(vm_error)}")
raise azclierror.BadRequestError("Unable to determine the target machine type as Azure VM or "
f"Arc Server. Errors:\n{str(arc_error)}\n{str(vm_error)}")
Expand Down Expand Up @@ -457,3 +467,23 @@ def _check_if_arc_server(cmd, resource_group_name, vm_name):
except HttpResponseError as e:
return e, False
return None, True


def _validate_arc_server(cmd, resource_group_name, vm_name):
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name)
if not is_arc_server:
if isinstance(arc_error, ResourceNotFoundError):
raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group "
f"{resource_group_name} was not found. Error:\n"
f"{str(arc_error)}")
if isinstance(arc_error, HttpResponseError) and arc_error.status_code == 403:
raise azclierror.ClientRequestError(f"{str(arc_error)}\n\n"
"Make sure you have owner or contributor role. "
"If you have the correct role assigned, please ensure "
"you are using the correct subscription "
"by running \"az account show\". "
"Run \"az account set -s <subscription id>\" to activate "
f"the correct subscription.\n")
raise azclierror.BadRequestError(f"Arc Server validation failed.\nError:\n{str(arc_error)}")
return is_arc_server
12 changes: 9 additions & 3 deletions src/ssh/azext_ssh/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def test_ssh_config_credentials_folder_and_key(self):

@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.custom._do_ssh_op')
def test_ssh_arc(self, mock_do_op, mock_assert_args):
@mock.patch('azext_ssh.custom._validate_arc_server')
def test_ssh_arc(self, mock_validate, mock_do_op, mock_assert_args):
cmd = mock.Mock()
mock_validate.return_value = True
custom.ssh_arc(cmd, 'rg', 'vm', 'public', 'private', 'user', 'cert', 'port', 'path/to/ssh', False, 'ssh_args')
mock_assert_args.assert_called_once_with('rg', 'vm', None, 'Microsoft.HybridCompute', 'cert', 'user')
mock_do_op.assert_called_once_with(cmd, 'vm', 'rg', None, 'public', 'private', 'user', 'cert', 'port', False, None, mock.ANY, True)
Expand Down Expand Up @@ -115,8 +117,10 @@ def test_delete_credentials_not_cloudshell(self, mock_getenv):
@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.custom._do_ssh_op')
@mock.patch('os.environ.get')
def test_ssh_arc_delete_credentials_cloudshell(self, mock_getenv, mock_do_op, mock_assert_args):
@mock.patch('azext_ssh.custom._validate_arc_server')
def test_ssh_arc_delete_credentials_cloudshell(self, mock_validate, mock_getenv, mock_do_op, mock_assert_args):
mock_getenv.return_value = "cloud-shell/1.0"
mock_validate.return_value = True
cmd = mock.Mock()
custom.ssh_arc(cmd, 'rg', 'vm', 'public', 'private', 'user', 'cert', 'port', 'path/to/ssh', True, 'ssh_args')
mock_assert_args.assert_called_once_with('rg', 'vm', None, 'Microsoft.HybridCompute', 'cert', 'user')
Expand Down Expand Up @@ -214,8 +218,10 @@ def test_decide_op_call_config_with_name_and_resource_group(self, mock_check_arc
mock_check_arc.assert_called_once_with(cmd, 'rg', 'vm')
mock_check_az_vm.assert_called_once_with(cmd, 'rg', 'vm')

def test_decide_op_call_arc_with_resource_type(self):
@mock.patch('azext_ssh.custom._validate_arc_server')
def test_decide_op_call_arc_with_resource_type(self, mock_validate):
cmd = mock.Mock()
mock_validate.return_value = True
expected_result = functools.partial(custom._do_ssh_op, is_arc=True,
op_call=functools.partial(ssh_utils.start_ssh_connection, ssh_client_path='path', ssh_args='args', delete_credentials=True))
result = custom._decide_op_call(cmd, 'rg', 'vm', None, 'Microsoft.HybridCompute', None, False, "path", "args", True, None)
Expand Down