diff --git a/src/bastion/HISTORY.rst b/src/bastion/HISTORY.rst index 8c34bccfff8..988dd1b1584 100644 --- a/src/bastion/HISTORY.rst +++ b/src/bastion/HISTORY.rst @@ -3,6 +3,12 @@ Release History =============== +0.2.0 +++++++ +* Adding support for IP connect through AZ CLI. +* Initial support for connectivity through developerSku. +* Bug fixes. + 0.1.0 ++++++ * Initial release. \ No newline at end of file diff --git a/src/bastion/README.md b/src/bastion/README.md index 3912f3a39df..9f55ee0e9c7 100644 --- a/src/bastion/README.md +++ b/src/bastion/README.md @@ -28,3 +28,13 @@ az network bastion show --name MyBastionHost --resource-group MyResourceGroup ```commandline az network bastion update --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling ``` + +### RDP to VM/VMSS using Azure Bastion host machine +```commandline +az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id ResourceId +``` + +### SSH to VM/VMSS using Azure Bastion host machine +```commandline +az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling --target-resource-id ResourceId --auth-type password +``` diff --git a/src/bastion/azext_bastion/BastionServiceConstants.py b/src/bastion/azext_bastion/BastionServiceConstants.py new file mode 100644 index 00000000000..40db1bdec9b --- /dev/null +++ b/src/bastion/azext_bastion/BastionServiceConstants.py @@ -0,0 +1,16 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=import-error,unused-import + +from enum import Enum + + +class BastionSku(Enum): + + Basic = "Basic" + Standard = "Standard" + Developer = "Developer" + QuickConnect = "QuickConnect" diff --git a/src/bastion/azext_bastion/_help.py b/src/bastion/azext_bastion/_help.py index 23e67da89f0..83cecad2f2c 100644 --- a/src/bastion/azext_bastion/_help.py +++ b/src/bastion/azext_bastion/_help.py @@ -24,6 +24,9 @@ - name: SSH to virtual machine using Azure Bastion using AAD. text: | az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD + - name: SSH to virtual machine using Azure Bastion using AAD. + text: | + az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD """ helps['network bastion rdp'] = """ @@ -33,13 +36,19 @@ - name: RDP to virtual machine using Azure Bastion. text: | az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId + - name: RDP to machine using reachable IP address. + text: | + az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 """ helps['network bastion tunnel'] = """ type: command short-summary: Open a tunnel through Azure Bastion to a target virtual machine. examples: - - name: Open a tunnel through Azure Bastion to a target virtual machine. + - name: Open a tunnel through Azure Bastion to a target virtual machine using resourceId. text: | az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --resource-port 22 --port 50022 + - name: Open a tunnel through Azure Bastion to a target virtual machine using its IP address. + text: | + az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 --resource-port 22 --port 50022 """ diff --git a/src/bastion/azext_bastion/_params.py b/src/bastion/azext_bastion/_params.py index 862ff05e65b..565bf887d25 100644 --- a/src/bastion/azext_bastion/_params.py +++ b/src/bastion/azext_bastion/_params.py @@ -10,6 +10,7 @@ from azure.cli.core.commands.parameters import get_resource_name_completion_list, get_three_state_flag from knack.arguments import CLIArgumentType +from ._validators import (validate_ip_address) def load_arguments(self, _): # pylint: disable=unused-argument @@ -24,8 +25,10 @@ def load_arguments(self, _): # pylint: disable=unused-argument c.argument("bastion_host_name", bastion_host_name_type, options_list=["--name", "-n"]) c.argument("resource_port", help="Resource port of the target VM to which the bastion will connect.", options_list=["--resource-port"]) - c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", + c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", required=False, options_list=["--target-resource-id"]) + c.argument("target_ip_address", help="IP address of target Virtual Machine.", required=False, + options_list=["--target-ip-address"], validator=validate_ip_address) with self.argument_context("network bastion ssh") as c: c.argument("auth_type", help="Auth type to use for SSH connections.", options_list=["--auth-type"]) diff --git a/src/bastion/azext_bastion/_validators.py b/src/bastion/azext_bastion/_validators.py new file mode 100644 index 00000000000..5e368869e03 --- /dev/null +++ b/src/bastion/azext_bastion/_validators.py @@ -0,0 +1,28 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import ipaddress +from azure.cli.core.azclierror import InvalidArgumentValueError + + +def validate_ip_address(namespace): + if namespace.target_ip_address is not None: + _validate_ip_address_format(namespace) + + +def _validate_ip_address_format(namespace): + if namespace.target_ip_address is not None: + input_value = namespace.target_ip_address + if ' ' in input_value: + raise InvalidArgumentValueError("Spaces not allowed: '{}' ".format(input_value)) + input_ips = input_value.split(',') + if len(input_ips) > 8: + raise InvalidArgumentValueError('Maximum 8 IP addresses are allowed per rule.') + validated_ips = '' + for ip in input_ips: + # Use ipaddress library to validate ip network format + ip_obj = ipaddress.ip_network(ip) + validated_ips += str(ip_obj) + ',' + namespace.target_ip_address = validated_ips[:-1] diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index a3ae521b72f..e6f6712625d 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -18,9 +18,10 @@ import requests from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \ UnrecognizedArgumentError, CLIInternalError, ClientRequestError +from azure.cli.core.commands.client_factory import get_subscription_id from knack.log import get_logger from msrestazure.tools import is_valid_resource_id - +from .BastionServiceConstants import BastionSku from .aaz.latest.network.bastion import Create as _BastionCreate @@ -132,20 +133,27 @@ def _build_args(cert_file, private_key_file): return private_key + certificate -def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name, +def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port=None, username=None, ssh_key=None): import os + from .aaz.latest.network.bastion import Show _test_extension(SSH_EXTENSION_NAME) + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ + "resource_group": resource_group_name, + "name": bastion_host_name + }) if not resource_port: resource_port = 22 - if not is_valid_resource_id(target_resource_id): - err_msg = "Please enter a valid resource ID. If this is not working, " \ - "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." - raise InvalidArgumentValueError(err_msg) - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port) + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: + raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + + _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) + + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -208,32 +216,33 @@ def _get_rdp_path(rdp_command="mstsc"): return rdp_path -def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_name, +def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port=None, disable_gateway=False, configure=False, enable_mfa=False): import os from azure.cli.core._profile import Profile from ._process_helper import launch_and_wait - - if not resource_port: - resource_port = 3389 - if not is_valid_resource_id(target_resource_id): - err_msg = "Please enter a valid resource ID. If this is not working, " \ - "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." - raise InvalidArgumentValueError(err_msg) - from .aaz.latest.network.bastion import Show + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ "resource_group": resource_group_name, "name": bastion_host_name }) - if bastion['sku']['name'] == "Basic" or \ - bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: + if not resource_port: + resource_port = 3389 + + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) + _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) + if platform.system() == "Windows": - if disable_gateway: - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port) + if disable_gateway or ip_connect: + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, bastion_endpoint) + if ip_connect: + tunnel_server.set_host_name(target_ip_address) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -244,9 +253,8 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" - web_address = f"https://{bastion['dnsName']}/api/rdpfile?resourceId={target_resource_id}" \ - f"&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" headers = { "Authorization": f"Bearer {access_token}", "Accept": "*/*", @@ -259,7 +267,6 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_ raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.") _write_to_file(response) - rdpfilepath = os.getcwd() + "/conn.rdp" command = [_get_rdp_path()] if configure: @@ -270,6 +277,33 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_ raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows") +def _is_ipconnect_request(cmd, bastion, target_ip_address): + if bastion['enableIpConnect'] is True and target_ip_address: + return True + + return False + + +def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address): + if target_ip_address: + if bastion['enableIpConnect'] is not True: + raise InvalidArgumentValueError("Bastion does not have IP Connect feature enabled, please enable and try again") + target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" + elif not is_valid_resource_id(target_resource_id): + err_msg = "Please enter a valid resource ID. If this is not working, " \ + "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." + raise InvalidArgumentValueError(err_msg) + + +def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): + if bastion['sku']['name'] == BastionSku.QuickConnect.value or bastion['sku']['name'] == BastionSku.Developer.value: + from .developer_sku_helper import (_get_data_pod) + bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion) + return bastion_endpoint + + return bastion['dnsName'] + + def _write_to_file(response): with open("conn.rdp", "w", encoding="utf-8") as f: for line in response.text.splitlines(): @@ -277,17 +311,12 @@ def _write_to_file(response): f.write(line + "\n") -def _get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None): +def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None): from .tunnel import TunnelServer - from .aaz.latest.network.bastion import Show - bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ - "resource_group": resource_group_name, - "name": name - }) if port is None: port = 0 # will auto-select a free port from 1024-65535 - tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, vm_id, resource_port) + tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, bastion_endpoint, vm_id, resource_port) return tunnel_server @@ -303,12 +332,24 @@ def _tunnel_close_handler(tunnel): sys.exit() -def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port, port, +def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port, port, timeout=None): if not is_valid_resource_id(target_resource_id): raise InvalidArgumentValueError("Please enter a valid VM resource ID.") - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port, port) + from .aaz.latest.network.bastion import Show + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ + "resource_group": resource_group_name, + "name": bastion_host_name + }) + + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: + raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + + _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) + + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py new file mode 100644 index 00000000000..f4dc92d07a7 --- /dev/null +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -0,0 +1,29 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=import-error,unused-import + + +def _get_data_pod(cmd, resource_port, target_resource_id, bastion): + from azure.cli.core._profile import Profile + from azure.cli.core.util import should_disable_connection_verify + import requests + + profile = Profile(cli_ctx=cmd.cli_ctx) + auth_token, _, _ = profile.get_raw_token() + content = { + 'resourceId': target_resource_id, + 'bastionResourceId': bastion.id, + 'vmPort': resource_port, + 'azToken': auth_token[1], + 'connectionType': 'nativeclient' + } + headers = {'Content-Type': 'application/json'} + + web_address = f"https://{bastion['dnsName']}/api/connection" + response = requests.post(web_address, json=content, headers=headers, + verify=(not should_disable_connection_verify())) + + return response.content.decode("utf-8") diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index f998ad9ab11..e5a40bbf802 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -24,6 +24,7 @@ from websocket import create_connection, WebSocket from msrestazure.azure_exceptions import CloudError +from .BastionServiceConstants import BastionSku from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify @@ -38,7 +39,7 @@ # pylint: disable=no-member,too-many-instance-attributes,bare-except,no-self-use class TunnelServer: - def __init__(self, cli_ctx, local_addr, local_port, bastion, remote_host, remote_port): + def __init__(self, cli_ctx, local_addr, local_port, bastion, bastion_endpoint, remote_host, remote_port): self.local_addr = local_addr self.local_port = int(local_port) if self.local_port != 0 and not self.is_port_open(): @@ -46,10 +47,12 @@ def __init__(self, cli_ctx, local_addr, local_port, bastion, remote_host, remote self.bastion = bastion self.remote_host = remote_host self.remote_port = remote_port + self.bastion_endpoint = bastion_endpoint self.client = None self.ws = None self.last_token = None self.node_id = None + self.host_name = None self.cli_ctx = cli_ctx logger.info('Creating a socket on port: %s', self.local_port) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -83,12 +86,14 @@ def _get_auth_token(self): 'aztoken': auth_token[1], 'token': self.last_token, } + if self.host_name: + content['hostname'] = self.host_name if self.node_id: custom_header = {'X-Node-Id': self.node_id} else: custom_header = {} - web_address = f"https://{self.bastion['dnsName']}/api/tokens" + web_address = f"https://{self.bastion_endpoint}/api/tokens" response = requests.post(web_address, data=content, headers=custom_header, verify=(not should_disable_connection_verify())) response_json = None @@ -115,7 +120,11 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - host = f"wss://{self.bastion['dnsName']}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" + if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name: + host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}" + else: + host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" + verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED self.ws = create_connection(host, sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), @@ -192,7 +201,7 @@ def cleanup(self): else: custom_header = {} - web_address = f"https://{self.bastion['dnsName']}/api/tokens/{self.last_token}" + web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}" response = requests.delete(web_address, headers=custom_header, verify=(not should_disable_connection_verify())) if response.status_code == 404: @@ -208,3 +217,6 @@ def cleanup(self): def get_port(self): return self.local_port + + def set_host_name(self, hostname): + self.host_name = hostname