From 66ff916cf709c1058be667dce599743e6e236a3f Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Thu, 19 Jan 2023 22:16:19 -0800 Subject: [PATCH 1/7] Initial changes for ipconnect and quickconnect --- src/bastion/azext_bastion/_params.py | 5 +- src/bastion/azext_bastion/_validators.py | 27 ++++++ src/bastion/azext_bastion/custom.py | 95 +++++++++++++------ .../azext_bastion/developer_sku_helper.py | 27 ++++++ src/bastion/azext_bastion/tunnel.py | 15 ++- 5 files changed, 133 insertions(+), 36 deletions(-) create mode 100644 src/bastion/azext_bastion/_validators.py create mode 100644 src/bastion/azext_bastion/developer_sku_helper.py diff --git a/src/bastion/azext_bastion/_params.py b/src/bastion/azext_bastion/_params.py index 862ff05e65b..565bf887d25 100644 --- a/src/bastion/azext_bastion/_params.py +++ b/src/bastion/azext_bastion/_params.py @@ -10,6 +10,7 @@ from azure.cli.core.commands.parameters import get_resource_name_completion_list, get_three_state_flag from knack.arguments import CLIArgumentType +from ._validators import (validate_ip_address) def load_arguments(self, _): # pylint: disable=unused-argument @@ -24,8 +25,10 @@ def load_arguments(self, _): # pylint: disable=unused-argument c.argument("bastion_host_name", bastion_host_name_type, options_list=["--name", "-n"]) c.argument("resource_port", help="Resource port of the target VM to which the bastion will connect.", options_list=["--resource-port"]) - c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", + c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", required=False, options_list=["--target-resource-id"]) + c.argument("target_ip_address", help="IP address of target Virtual Machine.", required=False, + options_list=["--target-ip-address"], validator=validate_ip_address) with self.argument_context("network bastion ssh") as c: c.argument("auth_type", help="Auth type to use for SSH connections.", options_list=["--auth-type"]) diff --git a/src/bastion/azext_bastion/_validators.py b/src/bastion/azext_bastion/_validators.py new file mode 100644 index 00000000000..a18e65e1f2a --- /dev/null +++ b/src/bastion/azext_bastion/_validators.py @@ -0,0 +1,27 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import ipaddress +from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \ + UnrecognizedArgumentError, CLIInternalError, ClientRequestError + +def validate_ip_address(cmd, namespace): + if namespace.target_ip_address is not None: + _validate_ip_address_format(namespace) + +def _validate_ip_address_format(namespace): + if namespace.target_ip_address is not None: + input_value = namespace.target_ip_address + if ' ' in input_value: + raise InvalidArgumentValueError("Spaces not allowed: '{}' ".format(input_value)) + input_ips = input_value.split(',') + if len(input_ips) > 8: + raise InvalidArgumentValueError('Maximum 8 IP addresses are allowed per rule.') + validated_ips = '' + for ip in input_ips: + # Use ipaddress library to validate ip network format + ip_obj = ipaddress.ip_network(ip) + validated_ips += str(ip_obj) + ',' + namespace.target_ip_address = validated_ips[:-1] \ No newline at end of file diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index a3ae521b72f..69f2055ba10 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -135,17 +135,20 @@ def _build_args(cert_file, private_key_file): def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name, resource_port=None, username=None, ssh_key=None): import os + from .aaz.latest.network.bastion import Show _test_extension(SSH_EXTENSION_NAME) if not resource_port: resource_port = 22 - if not is_valid_resource_id(target_resource_id): - err_msg = "Please enter a valid resource ID. If this is not working, " \ - "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." - raise InvalidArgumentValueError(err_msg) + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: + raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + + validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port) + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -208,32 +211,34 @@ def _get_rdp_path(rdp_command="mstsc"): return rdp_path -def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_name, +def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port=None, disable_gateway=False, configure=False, enable_mfa=False): import os from azure.cli.core._profile import Profile from ._process_helper import launch_and_wait - - if not resource_port: - resource_port = 3389 - if not is_valid_resource_id(target_resource_id): - err_msg = "Please enter a valid resource ID. If this is not working, " \ - "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." - raise InvalidArgumentValueError(err_msg) - + from azure.cli.core.commands.client_factory import get_subscription_id from .aaz.latest.network.bastion import Show + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ "resource_group": resource_group_name, "name": bastion_host_name }) - if bastion['sku']['name'] == "Basic" or \ - bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: + if not resource_port: + resource_port = 3389 + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) + _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) + if platform.system() == "Windows": - if disable_gateway: - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port) + if disable_gateway or ip_connect: + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, bastion_endpoint) + if ip_connect: + tunnel_server.set_host_name(target_ip_address) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() @@ -244,9 +249,8 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - - web_address = f"https://{bastion['dnsName']}/api/rdpfile?resourceId={target_resource_id}" \ - f"&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" headers = { "Authorization": f"Bearer {access_token}", "Accept": "*/*", @@ -269,6 +273,29 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_ else: raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows") +def _is_ipconnect_request(cmd, bastion, target_ip_address): + if bastion['enableIpConnect'] == True and target_ip_address: + return True + + return False + +def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address): + if target_ip_address: + if bastion['enableIpConnect'] != True: + raise InvalidArgumentValueError("Bastion does not have IP Connect feature enabled, please enable and try again") + target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" + elif not is_valid_resource_id(target_resource_id): + err_msg = "Please enter a valid resource ID. If this is not working, " \ + "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." + raise InvalidArgumentValueError(err_msg) + +def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): + if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + from .developer_sku_helper import (_get_data_pod) + bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion) + return bastion_endpoint + + return bastion['dnsName'] def _write_to_file(response): with open("conn.rdp", "w", encoding="utf-8") as f: @@ -277,17 +304,12 @@ def _write_to_file(response): f.write(line + "\n") -def _get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None): +def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None): from .tunnel import TunnelServer - from .aaz.latest.network.bastion import Show - - bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ - "resource_group": resource_group_name, - "name": name - }) + if port is None: port = 0 # will auto-select a free port from 1024-65535 - tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, vm_id, resource_port) + tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, bastion_endpoint, vm_id, resource_port) return tunnel_server @@ -295,7 +317,6 @@ def _get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None) def _start_tunnel(tunnel_server): tunnel_server.start_server() - def _tunnel_close_handler(tunnel): logger.info("Ctrl + C received. Clean up and then exit.") tunnel.cleanup() @@ -308,7 +329,19 @@ def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_ if not is_valid_resource_id(target_resource_id): raise InvalidArgumentValueError("Please enter a valid VM resource ID.") - tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port, port) + from .aaz.latest.network.bastion import Show + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ + "resource_group": resource_group_name, + "name": bastion_host_name + }) + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: + raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') + + validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) + + tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port) t = threading.Thread(target=_start_tunnel, args=(tunnel_server,)) t.daemon = True t.start() diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py new file mode 100644 index 00000000000..04d694208a0 --- /dev/null +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -0,0 +1,27 @@ +def _get_data_pod(cmd, resource_port, target_resource_id, bastion): + from azure.cli.core._profile import Profile + from azure.cli.core.util import should_disable_connection_verify + import json + import requests + + + profile = Profile(cli_ctx=cmd.cli_ctx) + auth_token, _, _ = profile.get_raw_token() + content = { + 'resourceId': target_resource_id, + 'bastionResourceId': bastion.id, + 'vmPort': resource_port, + 'azToken': auth_token[1], + 'connectionType' : 'nativeclient' + } + headers = { + 'Content-Type': 'application/json', + } + + web_address = 'https://{}/api/connection'.format(bastion.dns_name) + response = requests.post(web_address, json=content, headers=headers, verify=(not should_disable_connection_verify())) + response_json = None + + return response.content.decode("utf-8") + + \ No newline at end of file diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index f998ad9ab11..4b457d0f839 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -38,7 +38,7 @@ # pylint: disable=no-member,too-many-instance-attributes,bare-except,no-self-use class TunnelServer: - def __init__(self, cli_ctx, local_addr, local_port, bastion, remote_host, remote_port): + def __init__(self, cli_ctx, local_addr, local_port, bastion, bastion_endpoint, remote_host, remote_port): self.local_addr = local_addr self.local_port = int(local_port) if self.local_port != 0 and not self.is_port_open(): @@ -46,10 +46,12 @@ def __init__(self, cli_ctx, local_addr, local_port, bastion, remote_host, remote self.bastion = bastion self.remote_host = remote_host self.remote_port = remote_port + self.bastion_endpoint = bastion_endpoint self.client = None self.ws = None self.last_token = None self.node_id = None + self.host_name = None self.cli_ctx = cli_ctx logger.info('Creating a socket on port: %s', self.local_port) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -83,12 +85,14 @@ def _get_auth_token(self): 'aztoken': auth_token[1], 'token': self.last_token, } + if self.host_name: + content['hostname'] = self.host_name if self.node_id: custom_header = {'X-Node-Id': self.node_id} else: custom_header = {} - web_address = f"https://{self.bastion['dnsName']}/api/tokens" + web_address = f"https://{self.bastion_endpoint}/api/tokens" response = requests.post(web_address, data=content, headers=custom_header, verify=(not should_disable_connection_verify())) response_json = None @@ -115,7 +119,7 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - host = f"wss://{self.bastion['dnsName']}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" + host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED self.ws = create_connection(host, sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), @@ -192,7 +196,7 @@ def cleanup(self): else: custom_header = {} - web_address = f"https://{self.bastion['dnsName']}/api/tokens/{self.last_token}" + web_address = f"https://{self.bastion_endpoint}/api/tokens/{self.last_token}" response = requests.delete(web_address, headers=custom_header, verify=(not should_disable_connection_verify())) if response.status_code == 404: @@ -208,3 +212,6 @@ def cleanup(self): def get_port(self): return self.local_port + + def set_host_name(self, hostname): + self.host_name = hostname From ffc32686c308d16370cd5f52c86b18314ec2b41a Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Fri, 20 Jan 2023 10:36:56 -0800 Subject: [PATCH 2/7] adding endpoints for quickconnect/developer sku --- src/bastion/azext_bastion/custom.py | 8 ++++++-- src/bastion/azext_bastion/tunnel.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 69f2055ba10..9ac1554a56f 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -250,7 +250,11 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + web_address = f"https://{bastion_endpoint}/api/omni/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + else: + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + headers = { "Authorization": f"Bearer {access_token}", "Accept": "*/*", @@ -334,7 +338,7 @@ def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_ "resource_group": resource_group_name, "name": bastion_host_name }) - + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index 4b457d0f839..857dbfea131 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -119,7 +119,11 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" + if self.bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}" + else: + host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" + verify_mode = ssl.CERT_NONE if should_disable_connection_verify() else ssl.CERT_REQUIRED self.ws = create_connection(host, sockopt=((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),), From 5696bfcfc532ee4fd7c91590b1346ded14cb1f1d Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Tue, 24 Jan 2023 11:25:38 -0800 Subject: [PATCH 3/7] stlying fixes --- src/bastion/azext_bastion/_validators.py | 9 ++-- src/bastion/azext_bastion/custom.py | 47 +++++++++++-------- .../azext_bastion/developer_sku_helper.py | 16 ++----- src/bastion/azext_bastion/tunnel.py | 4 +- 4 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/bastion/azext_bastion/_validators.py b/src/bastion/azext_bastion/_validators.py index a18e65e1f2a..5e368869e03 100644 --- a/src/bastion/azext_bastion/_validators.py +++ b/src/bastion/azext_bastion/_validators.py @@ -4,13 +4,14 @@ # -------------------------------------------------------------------------------------------- import ipaddress -from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \ - UnrecognizedArgumentError, CLIInternalError, ClientRequestError +from azure.cli.core.azclierror import InvalidArgumentValueError -def validate_ip_address(cmd, namespace): + +def validate_ip_address(namespace): if namespace.target_ip_address is not None: _validate_ip_address_format(namespace) + def _validate_ip_address_format(namespace): if namespace.target_ip_address is not None: input_value = namespace.target_ip_address @@ -24,4 +25,4 @@ def _validate_ip_address_format(namespace): # Use ipaddress library to validate ip network format ip_obj = ipaddress.ip_network(ip) validated_ips += str(ip_obj) + ',' - namespace.target_ip_address = validated_ips[:-1] \ No newline at end of file + namespace.target_ip_address = validated_ips[:-1] diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 9ac1554a56f..6026478f844 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -18,6 +18,7 @@ import requests from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \ UnrecognizedArgumentError, CLIInternalError, ClientRequestError +from azure.cli.core.commands.client_factory import get_subscription_id from knack.log import get_logger from msrestazure.tools import is_valid_resource_id @@ -132,20 +133,24 @@ def _build_args(cert_file, private_key_file): return private_key + certificate -def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name, +def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port=None, username=None, ssh_key=None): import os from .aaz.latest.network.bastion import Show _test_extension(SSH_EXTENSION_NAME) + bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ + "resource_group": resource_group_name, + "name": bastion_host_name + }) if not resource_port: resource_port = 22 - - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') - validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port) @@ -216,7 +221,6 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ import os from azure.cli.core._profile import Profile from ._process_helper import launch_and_wait - from azure.cli.core.commands.client_factory import get_subscription_id from .aaz.latest.network.bastion import Show bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={ @@ -226,8 +230,8 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ if not resource_port: resource_port = 3389 - - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) @@ -249,8 +253,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - - if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": web_address = f"https://{bastion_endpoint}/api/omni/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" else: web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" @@ -267,7 +270,6 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.") _write_to_file(response) - rdpfilepath = os.getcwd() + "/conn.rdp" command = [_get_rdp_path()] if configure: @@ -277,30 +279,34 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ else: raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows") + def _is_ipconnect_request(cmd, bastion, target_ip_address): - if bastion['enableIpConnect'] == True and target_ip_address: + if bastion['enableIpConnect'] is True and target_ip_address: return True - + return False + def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address): if target_ip_address: - if bastion['enableIpConnect'] != True: + if bastion['enableIpConnect'] is not True: raise InvalidArgumentValueError("Bastion does not have IP Connect feature enabled, please enable and try again") target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}" elif not is_valid_resource_id(target_resource_id): err_msg = "Please enter a valid resource ID. If this is not working, " \ "try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID." raise InvalidArgumentValueError(err_msg) - + + def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): - if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": from .developer_sku_helper import (_get_data_pod) bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion) return bastion_endpoint return bastion['dnsName'] + def _write_to_file(response): with open("conn.rdp", "w", encoding="utf-8") as f: for line in response.text.splitlines(): @@ -310,7 +316,7 @@ def _write_to_file(response): def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None): from .tunnel import TunnelServer - + if port is None: port = 0 # will auto-select a free port from 1024-65535 tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, bastion_endpoint, vm_id, resource_port) @@ -321,6 +327,7 @@ def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None) def _start_tunnel(tunnel_server): tunnel_server.start_server() + def _tunnel_close_handler(tunnel): logger.info("Ctrl + C received. Clean up and then exit.") tunnel.cleanup() @@ -328,7 +335,7 @@ def _tunnel_close_handler(tunnel): sys.exit() -def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port, port, +def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port, port, timeout=None): if not is_valid_resource_id(target_resource_id): raise InvalidArgumentValueError("Please enter a valid VM resource ID.") @@ -338,11 +345,11 @@ def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_ "resource_group": resource_group_name, "name": bastion_host_name }) - - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] != True: + + if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') - validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) + _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id) tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port) diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py index 04d694208a0..bd3e7940ea7 100644 --- a/src/bastion/azext_bastion/developer_sku_helper.py +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -1,10 +1,8 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion): from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify - import json import requests - profile = Profile(cli_ctx=cmd.cli_ctx) auth_token, _, _ = profile.get_raw_token() content = { @@ -12,16 +10,12 @@ def _get_data_pod(cmd, resource_port, target_resource_id, bastion): 'bastionResourceId': bastion.id, 'vmPort': resource_port, 'azToken': auth_token[1], - 'connectionType' : 'nativeclient' + 'connectionType': 'nativeclient' } - headers = { - 'Content-Type': 'application/json', - } + headers = {'Content-Type': 'application/json'} - web_address = 'https://{}/api/connection'.format(bastion.dns_name) - response = requests.post(web_address, json=content, headers=headers, verify=(not should_disable_connection_verify())) - response_json = None + web_address = f"https://{bastion['dnsName']}/api/connection" + response = requests.post(web_address, json=content, headers=headers, + verify=(not should_disable_connection_verify())) return response.content.decode("utf-8") - - \ No newline at end of file diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index 857dbfea131..01939b28770 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -119,7 +119,7 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - if self.bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + if self.bastion['sku']['name'] == "QuickConnect" or self.bastion['sku']['name'] == "Developer": host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}" else: host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" @@ -216,6 +216,6 @@ def cleanup(self): def get_port(self): return self.local_port - + def set_host_name(self, hostname): self.host_name = hostname From db01bd825478bab1bd8a3529478668e952de473b Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Tue, 24 Jan 2023 12:06:25 -0800 Subject: [PATCH 4/7] adding constant file with enum for skus --- .../azext_bastion/BastionServiceConstants.py | 14 ++++++++++++++ src/bastion/azext_bastion/custom.py | 12 ++++++------ src/bastion/azext_bastion/developer_sku_helper.py | 7 +++++++ src/bastion/azext_bastion/tunnel.py | 3 ++- 4 files changed, 29 insertions(+), 7 deletions(-) create mode 100644 src/bastion/azext_bastion/BastionServiceConstants.py diff --git a/src/bastion/azext_bastion/BastionServiceConstants.py b/src/bastion/azext_bastion/BastionServiceConstants.py new file mode 100644 index 00000000000..940d06b75ca --- /dev/null +++ b/src/bastion/azext_bastion/BastionServiceConstants.py @@ -0,0 +1,14 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=import-error,unused-import + +from enum import Enum + +class BastionSku(Enum): + Basic = 1 + Standard = 2 + Developer = 3 + QuickConnect = 4 \ No newline at end of file diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 6026478f844..81f654a77ae 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -21,7 +21,7 @@ from azure.cli.core.commands.client_factory import get_subscription_id from knack.log import get_logger from msrestazure.tools import is_valid_resource_id - +from .BastionServiceConstants import BastionSku from .aaz.latest.network.bastion import Create as _BastionCreate @@ -147,7 +147,7 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso if not resource_port: resource_port = 22 - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) @@ -231,7 +231,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ if not resource_port: resource_port = 3389 - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) @@ -253,7 +253,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + if bastion['sku']['name'] == BastionSku.QuickConnect.name or bastion['sku']['name'] == BastionSku.Developer.name: web_address = f"https://{bastion_endpoint}/api/omni/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" else: web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" @@ -299,7 +299,7 @@ def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_ def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): - if bastion['sku']['name'] == "QuickConnect" or bastion['sku']['name'] == "Developer": + if bastion['sku']['name'] == BastionSku.QuickConnect.name or bastion['sku']['name'] == BastionSku.Developer.name: from .developer_sku_helper import (_get_data_pod) bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion) return bastion_endpoint @@ -346,7 +346,7 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g "name": bastion_host_name }) - if bastion['sku']['name'] == "Basic" or bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py index bd3e7940ea7..874d8c1472c 100644 --- a/src/bastion/azext_bastion/developer_sku_helper.py +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -1,3 +1,10 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=import-error,unused-import + def _get_data_pod(cmd, resource_port, target_resource_id, bastion): from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify diff --git a/src/bastion/azext_bastion/tunnel.py b/src/bastion/azext_bastion/tunnel.py index 01939b28770..e5a40bbf802 100644 --- a/src/bastion/azext_bastion/tunnel.py +++ b/src/bastion/azext_bastion/tunnel.py @@ -24,6 +24,7 @@ from websocket import create_connection, WebSocket from msrestazure.azure_exceptions import CloudError +from .BastionServiceConstants import BastionSku from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify @@ -119,7 +120,7 @@ def _listen(self): self.client, _address = self.sock.accept() auth_token = self._get_auth_token() - if self.bastion['sku']['name'] == "QuickConnect" or self.bastion['sku']['name'] == "Developer": + if self.bastion['sku']['name'] == BastionSku.QuickConnect.name or self.bastion['sku']['name'] == BastionSku.Developer.name: host = f"wss://{self.bastion_endpoint}/omni/webtunnel/{auth_token}" else: host = f"wss://{self.bastion_endpoint}/webtunnelv2/{auth_token}?X-Node-Id={self.node_id}" From 8a8c44856b1cf6e51363c7adcf254c971de62866 Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Wed, 25 Jan 2023 17:04:16 -0800 Subject: [PATCH 5/7] Documentation and final changes --- src/bastion/HISTORY.rst | 6 ++++++ src/bastion/README.md | 10 ++++++++++ .../azext_bastion/BastionServiceConstants.py | 12 +++++++----- src/bastion/azext_bastion/_help.py | 13 +++++++++++-- src/bastion/azext_bastion/custom.py | 13 +++++-------- 5 files changed, 39 insertions(+), 15 deletions(-) diff --git a/src/bastion/HISTORY.rst b/src/bastion/HISTORY.rst index 8c34bccfff8..988dd1b1584 100644 --- a/src/bastion/HISTORY.rst +++ b/src/bastion/HISTORY.rst @@ -3,6 +3,12 @@ Release History =============== +0.2.0 +++++++ +* Adding support for IP connect through AZ CLI. +* Initial support for connectivity through developerSku. +* Bug fixes. + 0.1.0 ++++++ * Initial release. \ No newline at end of file diff --git a/src/bastion/README.md b/src/bastion/README.md index 3912f3a39df..9f55ee0e9c7 100644 --- a/src/bastion/README.md +++ b/src/bastion/README.md @@ -28,3 +28,13 @@ az network bastion show --name MyBastionHost --resource-group MyResourceGroup ```commandline az network bastion update --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling ``` + +### RDP to VM/VMSS using Azure Bastion host machine +```commandline +az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id ResourceId +``` + +### SSH to VM/VMSS using Azure Bastion host machine +```commandline +az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling --target-resource-id ResourceId --auth-type password +``` diff --git a/src/bastion/azext_bastion/BastionServiceConstants.py b/src/bastion/azext_bastion/BastionServiceConstants.py index 940d06b75ca..40db1bdec9b 100644 --- a/src/bastion/azext_bastion/BastionServiceConstants.py +++ b/src/bastion/azext_bastion/BastionServiceConstants.py @@ -6,9 +6,11 @@ # pylint: disable=import-error,unused-import from enum import Enum - + + class BastionSku(Enum): - Basic = 1 - Standard = 2 - Developer = 3 - QuickConnect = 4 \ No newline at end of file + + Basic = "Basic" + Standard = "Standard" + Developer = "Developer" + QuickConnect = "QuickConnect" diff --git a/src/bastion/azext_bastion/_help.py b/src/bastion/azext_bastion/_help.py index 23e67da89f0..b15191cb020 100644 --- a/src/bastion/azext_bastion/_help.py +++ b/src/bastion/azext_bastion/_help.py @@ -17,13 +17,16 @@ examples: - name: SSH to virtual machine using Azure Bastion using password. text: | - az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type password --username xyz + az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id /subscriptions/ vmResourceId --auth-type password --username xyz - name: SSH to virtual machine using Azure Bastion using ssh key file. text: | az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type ssh-key --username xyz --ssh-key C:/filepath/sshkey.pem - name: SSH to virtual machine using Azure Bastion using AAD. text: | az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD + - name: SSH to virtual machine using Azure Bastion using AAD. + text: | + az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD """ helps['network bastion rdp'] = """ @@ -33,13 +36,19 @@ - name: RDP to virtual machine using Azure Bastion. text: | az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId + - name: RDP to machine using reachable IP address. + text: | + az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 """ helps['network bastion tunnel'] = """ type: command short-summary: Open a tunnel through Azure Bastion to a target virtual machine. examples: - - name: Open a tunnel through Azure Bastion to a target virtual machine. + - name: Open a tunnel through Azure Bastion to a target virtual machine using resourceId. text: | az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --resource-port 22 --port 50022 + - name: Open a tunnel through Azure Bastion to a target virtual machine using its IP address. + text: | + az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 --resource-port 22 --port 50022 """ diff --git a/src/bastion/azext_bastion/custom.py b/src/bastion/azext_bastion/custom.py index 81f654a77ae..e6f6712625d 100644 --- a/src/bastion/azext_bastion/custom.py +++ b/src/bastion/azext_bastion/custom.py @@ -147,7 +147,7 @@ def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, reso if not resource_port: resource_port = 22 - if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) @@ -231,7 +231,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ if not resource_port: resource_port = 3389 - if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address) @@ -253,10 +253,7 @@ def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_ profile = Profile(cli_ctx=cmd.cli_ctx) access_token = profile.get_raw_token()[0][2].get("accessToken") logger.debug("Response %s", access_token) - if bastion['sku']['name'] == BastionSku.QuickConnect.name or bastion['sku']['name'] == BastionSku.Developer.name: - web_address = f"https://{bastion_endpoint}/api/omni/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" - else: - web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" + web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}" headers = { "Authorization": f"Bearer {access_token}", @@ -299,7 +296,7 @@ def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_ def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id): - if bastion['sku']['name'] == BastionSku.QuickConnect.name or bastion['sku']['name'] == BastionSku.Developer.name: + if bastion['sku']['name'] == BastionSku.QuickConnect.value or bastion['sku']['name'] == BastionSku.Developer.value: from .developer_sku_helper import (_get_data_pod) bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion) return bastion_endpoint @@ -346,7 +343,7 @@ def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_g "name": bastion_host_name }) - if bastion['sku']['name'] == BastionSku.Basic.name or bastion['sku']['name'] == BastionSku.Standard.name and bastion['enableTunneling'] is not True: + if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True: raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.') _validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address) From 58c5e8543ec9058ce2411dd5010fd043c017ec43 Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Wed, 25 Jan 2023 17:11:59 -0800 Subject: [PATCH 6/7] Documentation and final changes --- src/bastion/azext_bastion/_help.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bastion/azext_bastion/_help.py b/src/bastion/azext_bastion/_help.py index b15191cb020..83cecad2f2c 100644 --- a/src/bastion/azext_bastion/_help.py +++ b/src/bastion/azext_bastion/_help.py @@ -17,7 +17,7 @@ examples: - name: SSH to virtual machine using Azure Bastion using password. text: | - az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id /subscriptions/ vmResourceId --auth-type password --username xyz + az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type password --username xyz - name: SSH to virtual machine using Azure Bastion using ssh key file. text: | az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type ssh-key --username xyz --ssh-key C:/filepath/sshkey.pem From 3441d0e183b7ee4f2a368b99df45e84f017f8f3d Mon Sep 17 00:00:00 2001 From: Aakash Radhakrishnan Date: Wed, 25 Jan 2023 17:37:29 -0800 Subject: [PATCH 7/7] Documentation and final changes --- src/bastion/azext_bastion/developer_sku_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/bastion/azext_bastion/developer_sku_helper.py b/src/bastion/azext_bastion/developer_sku_helper.py index 874d8c1472c..f4dc92d07a7 100644 --- a/src/bastion/azext_bastion/developer_sku_helper.py +++ b/src/bastion/azext_bastion/developer_sku_helper.py @@ -5,6 +5,7 @@ # pylint: disable=import-error,unused-import + def _get_data_pod(cmd, resource_port, target_resource_id, bastion): from azure.cli.core._profile import Profile from azure.cli.core.util import should_disable_connection_verify