diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 8ae5feb3e4f..1709e3e288b 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -92,9 +92,8 @@ def ssh_arc(cmd, resource_group_name=None, vm_name=None, public_key_file=None, p raise azclierror.ArgumentUsageError("Can't use --delete-private-key outside an Azure Cloud Shell session.") _assert_args(resource_group_name, vm_name, None, "Microsoft.HybridCompute", cert_file, local_user) - credentials_folder = None - + _validate_arc_server(cmd, resource_group_name, vm_name) op_call = functools.partial(ssh_utils.start_ssh_connection, ssh_client_path=ssh_client_path, ssh_args=ssh_args, delete_credentials=delete_credentials) _do_ssh_op(cmd, vm_name, resource_group_name, None, public_key_file, private_key_file, local_user, cert_file, port, @@ -107,8 +106,8 @@ def _do_ssh_op(cmd, vm_name, resource_group_name, ssh_ip, public_key_file, priva proxy_path = None relay_info = None if is_arc: - proxy_path = _arc_get_client_side_proxy() relay_info = _arc_list_access_details(cmd, resource_group_name, vm_name) + proxy_path = _arc_get_client_side_proxy() else: ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group_name, vm_name, use_private_ip) if not ssh_ip: @@ -361,6 +360,7 @@ def _arc_get_client_side_proxy(): # Get the Access Details to connect to Arc Connectivity platform from the HybridConnectivity RP def _arc_list_access_details(cmd, resource_group, vm_name): from azext_ssh._client_factory import cf_endpoint + from azure.core.exceptions import ResourceNotFoundError client = cf_endpoint(cmd.cli_ctx) try: t0 = time.time() @@ -368,11 +368,22 @@ def _arc_list_access_details(cmd, resource_group, vm_name): endpoint_name="default") time_elapsed = time.time() - t0 telemetry.add_extension_event('ssh', {'Context.Default.AzureCLI.SSHListCredentialsTime': time_elapsed}) + except ResourceNotFoundError as e: + telemetry.set_exception(exception='Call to listCredentials failed', + fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE, + summary=f'listCredentials failed with error: {str(e)}.') + # Probably indicates that the endpoint doesn't exist. Recommend they restart the agent. + raise azclierror.ResourceNotFoundError("Request for Azure Relay Information failed with a " + "ResourceNotFoundError.\n" + "Make sure the HybridConnectivity RP is registered and the " + "HybridConnectivity feature flag is enabled. " + "If error persists, please restart hybrid agent on target machine." + f"\nError:\n{str(e)}") except Exception as e: telemetry.set_exception(exception='Call to listCredentials failed', fault_type=consts.LIST_CREDENTIALS_FAILED_FAULT_TYPE, summary=f'listCredentials failed with error: {str(e)}.') - raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed: {str(e)}") + raise azclierror.ClientRequestError(f"Request for Azure Relay Information Failed. Error:\n{str(e)}") result_string = json.dumps( { @@ -392,7 +403,6 @@ def _arc_list_access_details(cmd, resource_group, vm_name): def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, config_path, overwrite, ssh_client_path, ssh_args, delete_credentials, credentials_folder): - # If the user provides an IP address the target will be treated as an Azure VM even if it is an # Arc Server. Which just means that the Connectivity Proxy won't be used to establish connection. is_arc_server = False @@ -402,7 +412,7 @@ def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, co elif resource_type: if resource_type == "Microsoft.HybridCompute": - is_arc_server = True + is_arc_server = _validate_arc_server(cmd, resource_group_name, vm_name) else: vm_error, is_azure_vm = _check_if_azure_vm(cmd, resource_group_name, vm_name) @@ -415,7 +425,7 @@ def _decide_op_call(cmd, resource_group_name, vm_name, ssh_ip, resource_type, co from azure.core.exceptions import ResourceNotFoundError if isinstance(arc_error, ResourceNotFoundError) and isinstance(vm_error, ResourceNotFoundError): raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " - f"{resource_group_name} was not found. Erros:\n" + f"{resource_group_name} was not found. Errors:\n" f"{str(arc_error)}\n{str(vm_error)}") raise azclierror.BadRequestError("Unable to determine the target machine type as Azure VM or " f"Arc Server. Errors:\n{str(arc_error)}\n{str(vm_error)}") @@ -457,3 +467,23 @@ def _check_if_arc_server(cmd, resource_group_name, vm_name): except HttpResponseError as e: return e, False return None, True + + +def _validate_arc_server(cmd, resource_group_name, vm_name): + from azure.core.exceptions import ResourceNotFoundError, HttpResponseError + arc_error, is_arc_server = _check_if_arc_server(cmd, resource_group_name, vm_name) + if not is_arc_server: + if isinstance(arc_error, ResourceNotFoundError): + raise azclierror.ResourceNotFoundError(f"The resource {vm_name} in the resource group " + f"{resource_group_name} was not found. Error:\n" + f"{str(arc_error)}") + if isinstance(arc_error, HttpResponseError) and arc_error.status_code == 403: + raise azclierror.ClientRequestError(f"{str(arc_error)}\n\n" + "Make sure you have owner or contributor role. " + "If you have the correct role assigned, please ensure " + "you are using the correct subscription " + "by running \"az account show\". " + "Run \"az account set -s \" to activate " + f"the correct subscription.\n") + raise azclierror.BadRequestError(f"Arc Server validation failed.\nError:\n{str(arc_error)}") + return is_arc_server diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 1a1c309adbc..cbd98d048d1 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -79,8 +79,10 @@ def test_ssh_config_credentials_folder_and_key(self): @mock.patch('azext_ssh.custom._assert_args') @mock.patch('azext_ssh.custom._do_ssh_op') - def test_ssh_arc(self, mock_do_op, mock_assert_args): + @mock.patch('azext_ssh.custom._validate_arc_server') + def test_ssh_arc(self, mock_validate, mock_do_op, mock_assert_args): cmd = mock.Mock() + mock_validate.return_value = True custom.ssh_arc(cmd, 'rg', 'vm', 'public', 'private', 'user', 'cert', 'port', 'path/to/ssh', False, 'ssh_args') mock_assert_args.assert_called_once_with('rg', 'vm', None, 'Microsoft.HybridCompute', 'cert', 'user') mock_do_op.assert_called_once_with(cmd, 'vm', 'rg', None, 'public', 'private', 'user', 'cert', 'port', False, None, mock.ANY, True) @@ -115,8 +117,10 @@ def test_delete_credentials_not_cloudshell(self, mock_getenv): @mock.patch('azext_ssh.custom._assert_args') @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('os.environ.get') - def test_ssh_arc_delete_credentials_cloudshell(self, mock_getenv, mock_do_op, mock_assert_args): + @mock.patch('azext_ssh.custom._validate_arc_server') + def test_ssh_arc_delete_credentials_cloudshell(self, mock_validate, mock_getenv, mock_do_op, mock_assert_args): mock_getenv.return_value = "cloud-shell/1.0" + mock_validate.return_value = True cmd = mock.Mock() custom.ssh_arc(cmd, 'rg', 'vm', 'public', 'private', 'user', 'cert', 'port', 'path/to/ssh', True, 'ssh_args') mock_assert_args.assert_called_once_with('rg', 'vm', None, 'Microsoft.HybridCompute', 'cert', 'user') @@ -214,8 +218,10 @@ def test_decide_op_call_config_with_name_and_resource_group(self, mock_check_arc mock_check_arc.assert_called_once_with(cmd, 'rg', 'vm') mock_check_az_vm.assert_called_once_with(cmd, 'rg', 'vm') - def test_decide_op_call_arc_with_resource_type(self): + @mock.patch('azext_ssh.custom._validate_arc_server') + def test_decide_op_call_arc_with_resource_type(self, mock_validate): cmd = mock.Mock() + mock_validate.return_value = True expected_result = functools.partial(custom._do_ssh_op, is_arc=True, op_call=functools.partial(ssh_utils.start_ssh_connection, ssh_client_path='path', ssh_args='args', delete_credentials=True)) result = custom._decide_op_call(cmd, 'rg', 'vm', None, 'Microsoft.HybridCompute', None, False, "path", "args", True, None)