Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/bastion/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Release History
===============

0.2.0
++++++
* Adding support for IP connect through AZ CLI.
* Initial support for connectivity through developerSku.
* Bug fixes.

0.1.0
++++++
* Initial release.
10 changes: 10 additions & 0 deletions src/bastion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,13 @@ az network bastion show --name MyBastionHost --resource-group MyResourceGroup
```commandline
az network bastion update --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling
```

### RDP to VM/VMSS using Azure Bastion host machine
```commandline
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id ResourceId
```

### SSH to VM/VMSS using Azure Bastion host machine
```commandline
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --enable-tunneling --target-resource-id ResourceId --auth-type password
```
16 changes: 16 additions & 0 deletions src/bastion/azext_bastion/BastionServiceConstants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=import-error,unused-import

from enum import Enum


class BastionSku(Enum):

Basic = "Basic"
Standard = "Standard"
Developer = "Developer"
QuickConnect = "QuickConnect"
11 changes: 10 additions & 1 deletion src/bastion/azext_bastion/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
- name: SSH to virtual machine using Azure Bastion using AAD.
text: |
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD
- name: SSH to virtual machine using Azure Bastion using AAD.
text: |
az network bastion ssh --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --auth-type AAD
"""

helps['network bastion rdp'] = """
Expand All @@ -33,13 +36,19 @@
- name: RDP to virtual machine using Azure Bastion.
text: |
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId
- name: RDP to machine using reachable IP address.
text: |
az network bastion rdp --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1
"""

helps['network bastion tunnel'] = """
type: command
short-summary: Open a tunnel through Azure Bastion to a target virtual machine.
examples:
- name: Open a tunnel through Azure Bastion to a target virtual machine.
- name: Open a tunnel through Azure Bastion to a target virtual machine using resourceId.
text: |
az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-resource-id vmResourceId --resource-port 22 --port 50022
- name: Open a tunnel through Azure Bastion to a target virtual machine using its IP address.
text: |
az network bastion tunnel --name MyBastionHost --resource-group MyResourceGroup --target-ip-address 10.0.0.1 --resource-port 22 --port 50022
"""
5 changes: 4 additions & 1 deletion src/bastion/azext_bastion/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from azure.cli.core.commands.parameters import get_resource_name_completion_list, get_three_state_flag
from knack.arguments import CLIArgumentType
from ._validators import (validate_ip_address)


def load_arguments(self, _): # pylint: disable=unused-argument
Expand All @@ -24,8 +25,10 @@ def load_arguments(self, _): # pylint: disable=unused-argument
c.argument("bastion_host_name", bastion_host_name_type, options_list=["--name", "-n"])
c.argument("resource_port", help="Resource port of the target VM to which the bastion will connect.",
options_list=["--resource-port"])
c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.",
c.argument("target_resource_id", help="ResourceId of the target Virtual Machine.", required=False,
options_list=["--target-resource-id"])
c.argument("target_ip_address", help="IP address of target Virtual Machine.", required=False,
options_list=["--target-ip-address"], validator=validate_ip_address)

with self.argument_context("network bastion ssh") as c:
c.argument("auth_type", help="Auth type to use for SSH connections.", options_list=["--auth-type"])
Expand Down
28 changes: 28 additions & 0 deletions src/bastion/azext_bastion/_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import ipaddress
from azure.cli.core.azclierror import InvalidArgumentValueError


def validate_ip_address(namespace):
if namespace.target_ip_address is not None:
_validate_ip_address_format(namespace)


def _validate_ip_address_format(namespace):
if namespace.target_ip_address is not None:
input_value = namespace.target_ip_address
if ' ' in input_value:
raise InvalidArgumentValueError("Spaces not allowed: '{}' ".format(input_value))
input_ips = input_value.split(',')
if len(input_ips) > 8:
raise InvalidArgumentValueError('Maximum 8 IP addresses are allowed per rule.')
validated_ips = ''
for ip in input_ips:
# Use ipaddress library to validate ip network format
ip_obj = ipaddress.ip_network(ip)
validated_ips += str(ip_obj) + ','
namespace.target_ip_address = validated_ips[:-1]
105 changes: 73 additions & 32 deletions src/bastion/azext_bastion/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
import requests
from azure.cli.core.azclierror import ValidationError, InvalidArgumentValueError, RequiredArgumentMissingError, \
UnrecognizedArgumentError, CLIInternalError, ClientRequestError
from azure.cli.core.commands.client_factory import get_subscription_id
from knack.log import get_logger
from msrestazure.tools import is_valid_resource_id

from .BastionServiceConstants import BastionSku
from .aaz.latest.network.bastion import Create as _BastionCreate


Expand Down Expand Up @@ -132,20 +133,27 @@ def _build_args(cert_file, private_key_file):
return private_key + certificate


def ssh_bastion_host(cmd, auth_type, target_resource_id, resource_group_name, bastion_host_name,
def ssh_bastion_host(cmd, auth_type, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, username=None, ssh_key=None):
import os
from .aaz.latest.network.bastion import Show

_test_extension(SSH_EXTENSION_NAME)
bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if not resource_port:
resource_port = 22
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)

tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

_validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand Down Expand Up @@ -208,32 +216,33 @@ def _get_rdp_path(rdp_command="mstsc"):
return rdp_path


def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_name,
def rdp_bastion_host(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name,
resource_port=None, disable_gateway=False, configure=False, enable_mfa=False):
import os
from azure.cli.core._profile import Profile
from ._process_helper import launch_and_wait

if not resource_port:
resource_port = 3389
if not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)

from .aaz.latest.network.bastion import Show

bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if bastion['sku']['name'] == "Basic" or \
bastion['sku']['name'] == "Standard" and bastion['enableTunneling'] is not True:
if not resource_port:
resource_port = 3389

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

ip_connect = _is_ipconnect_request(cmd, bastion, target_ip_address)
_validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

if platform.system() == "Windows":
if disable_gateway:
tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port)
if disable_gateway or ip_connect:
tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, bastion_endpoint)
if ip_connect:
tunnel_server.set_host_name(target_ip_address)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand All @@ -244,9 +253,8 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
profile = Profile(cli_ctx=cmd.cli_ctx)
access_token = profile.get_raw_token()[0][2].get("accessToken")
logger.debug("Response %s", access_token)
web_address = f"https://{bastion_endpoint}/api/rdpfile?resourceId={target_resource_id}&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"

web_address = f"https://{bastion['dnsName']}/api/rdpfile?resourceId={target_resource_id}" \
f"&format=rdp&rdpport={resource_port}&enablerdsaad={enable_mfa}"
headers = {
"Authorization": f"Bearer {access_token}",
"Accept": "*/*",
Expand All @@ -259,7 +267,6 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
raise ClientRequestError("Request to EncodingReservedUnitTypes v2 API endpoint failed.")

_write_to_file(response)

rdpfilepath = os.getcwd() + "/conn.rdp"
command = [_get_rdp_path()]
if configure:
Expand All @@ -270,24 +277,46 @@ def rdp_bastion_host(cmd, target_resource_id, resource_group_name, bastion_host_
raise UnrecognizedArgumentError("Platform is not supported for this command. Supported platforms: Windows")


def _is_ipconnect_request(cmd, bastion, target_ip_address):
if bastion['enableIpConnect'] is True and target_ip_address:
return True

return False


def _validate_and_generate_resourceid(cmd, bastion, resource_group_name, target_resource_id, target_ip_address):
if target_ip_address:
if bastion['enableIpConnect'] is not True:
raise InvalidArgumentValueError("Bastion does not have IP Connect feature enabled, please enable and try again")
target_resource_id = f"/subscriptions/{get_subscription_id(cmd.cli_ctx)}/resourceGroups/{resource_group_name}/providers/Microsoft.Network/bh-hostConnect/{target_ip_address}"
elif not is_valid_resource_id(target_resource_id):
err_msg = "Please enter a valid resource ID. If this is not working, " \
"try opening the JSON view of your resource (in the Overview tab), and copying the full resource ID."
raise InvalidArgumentValueError(err_msg)


def _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id):
if bastion['sku']['name'] == BastionSku.QuickConnect.value or bastion['sku']['name'] == BastionSku.Developer.value:
from .developer_sku_helper import (_get_data_pod)
bastion_endpoint = _get_data_pod(cmd, resource_port, target_resource_id, bastion)
return bastion_endpoint

return bastion['dnsName']


def _write_to_file(response):
with open("conn.rdp", "w", encoding="utf-8") as f:
for line in response.text.splitlines():
if not line.startswith('signscope'):
f.write(line + "\n")


def _get_tunnel(cmd, resource_group_name, name, vm_id, resource_port, port=None):
def _get_tunnel(cmd, bastion, bastion_endpoint, vm_id, resource_port, port=None):
from .tunnel import TunnelServer
from .aaz.latest.network.bastion import Show

bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": name
})
if port is None:
port = 0 # will auto-select a free port from 1024-65535
tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, vm_id, resource_port)
tunnel_server = TunnelServer(cmd.cli_ctx, "localhost", port, bastion, bastion_endpoint, vm_id, resource_port)

return tunnel_server

Expand All @@ -303,12 +332,24 @@ def _tunnel_close_handler(tunnel):
sys.exit()


def create_bastion_tunnel(cmd, target_resource_id, resource_group_name, bastion_host_name, resource_port, port,
def create_bastion_tunnel(cmd, target_resource_id, target_ip_address, resource_group_name, bastion_host_name, resource_port, port,
timeout=None):
if not is_valid_resource_id(target_resource_id):
raise InvalidArgumentValueError("Please enter a valid VM resource ID.")

tunnel_server = _get_tunnel(cmd, resource_group_name, bastion_host_name, target_resource_id, resource_port, port)
from .aaz.latest.network.bastion import Show
bastion = Show(cli_ctx=cmd.cli_ctx)(command_args={
"resource_group": resource_group_name,
"name": bastion_host_name
})

if bastion['sku']['name'] == BastionSku.Basic.value or bastion['sku']['name'] == BastionSku.Standard.value and bastion['enableTunneling'] is not True:
raise ClientRequestError('Bastion Host SKU must be Standard and Native Client must be enabled.')

_validate_and_generate_resourceid(cmd, bastion, target_resource_id, target_ip_address)
bastion_endpoint = _get_bastion_endpoint(cmd, bastion, resource_port, target_resource_id)

tunnel_server = _get_tunnel(cmd, bastion, bastion_endpoint, target_resource_id, resource_port, port)
t = threading.Thread(target=_start_tunnel, args=(tunnel_server,))
t.daemon = True
t.start()
Expand Down
29 changes: 29 additions & 0 deletions src/bastion/azext_bastion/developer_sku_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=import-error,unused-import


def _get_data_pod(cmd, resource_port, target_resource_id, bastion):
from azure.cli.core._profile import Profile
from azure.cli.core.util import should_disable_connection_verify
import requests

profile = Profile(cli_ctx=cmd.cli_ctx)
auth_token, _, _ = profile.get_raw_token()
content = {
'resourceId': target_resource_id,
'bastionResourceId': bastion.id,
'vmPort': resource_port,
'azToken': auth_token[1],
'connectionType': 'nativeclient'
}
headers = {'Content-Type': 'application/json'}

web_address = f"https://{bastion['dnsName']}/api/connection"
response = requests.post(web_address, json=content, headers=headers,
verify=(not should_disable_connection_verify()))

return response.content.decode("utf-8")
Loading