Skip to content
5 changes: 4 additions & 1 deletion src/azure-cli-core/azure/cli/core/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, # pylint: disable=unused-argument
synapse_analytics_resource_id=None,
attestation_resource_id=None,
portal=None,
azmirror_storage_account_resource_id=None,
**kwargs): # To support init with __dict__ for deserialization
# Attribute names are significant. They are used when storing/retrieving clouds from config
self.management = management
Expand All @@ -100,6 +101,7 @@ def __init__(self, # pylint: disable=unused-argument
self.synapse_analytics_resource_id = synapse_analytics_resource_id
self.attestation_resource_id = attestation_resource_id
self.portal = portal
self.azmirror_storage_account_resource_id = azmirror_storage_account_resource_id

def has_endpoint_set(self, endpoint_name):
try:
Expand Down Expand Up @@ -239,7 +241,8 @@ def _arm_to_cli_mapper(arm_dict):
synapse_analytics_resource_id=get_endpoint('synapseAnalyticsResourceId', fallback_value=get_endpoint_fallback_value('synapse_analytics_resource_id')),
app_insights_telemetry_channel_resource_id=get_endpoint('appInsightsTelemetryChannelResourceId', fallback_value=get_endpoint_fallback_value('app_insights_telemetry_channel_resource_id')),
attestation_resource_id=get_endpoint('attestationResourceId', fallback_value=get_endpoint_fallback_value('attestation_resource_id')),
portal=get_endpoint('portal')),
portal=get_endpoint('portal'),
azmirror_storage_account_resource_id=get_endpoint('azmirrorStorageAccountResourceId')),
suffixes=CloudSuffixes(
storage_endpoint=get_suffix('storage'),
storage_sync_endpoint=get_suffix('storageSyncEndpointSuffix', fallback_value=get_suffix_fallback_value('storage_sync_endpoint')),
Expand Down
25 changes: 21 additions & 4 deletions src/azure-cli-core/azure/cli/core/extension/_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,28 @@
TRIES = 3


def get_index_url(cli_ctx=None):
"""Use extension index url in the order of:
1. Environment variable: AZURE_EXTENSION_INDEX_URL
2. Config setting: extension.index_url
3. Index file in azmirror storage account cloud endpoint
4. DEFAULT_INDEX_URL
"""
import posixpath
if cli_ctx:
url = cli_ctx.config.get('extension', 'index_url', None)
if url:
return url
azmirror_endpoint = cli_ctx.cloud.endpoints.azmirror_storage_account_resource_id if cli_ctx and \
cli_ctx.cloud.endpoints.has_endpoint_set('azmirror_storage_account_resource_id') else None
return posixpath.join(azmirror_endpoint, 'extensions', 'index.json') if azmirror_endpoint else DEFAULT_INDEX_URL


# pylint: disable=inconsistent-return-statements
def get_index(index_url=None):
def get_index(index_url=None, cli_ctx=None):
import requests
from azure.cli.core.util import should_disable_connection_verify
index_url = index_url or DEFAULT_INDEX_URL
index_url = index_url or get_index_url(cli_ctx=cli_ctx)

for try_number in range(TRIES):
try:
Expand All @@ -45,8 +62,8 @@ def get_index(index_url=None):
continue


def get_index_extensions(index_url=None):
index = get_index(index_url=index_url)
def get_index_extensions(index_url=None, cli_ctx=None):
index = get_index(index_url=index_url, cli_ctx=cli_ctx)
extensions = index.get('extensions')
if extensions is None:
logger.warning(ERR_UNABLE_TO_GET_EXTENSIONS)
Expand Down
14 changes: 12 additions & 2 deletions src/azure-cli-core/azure/cli/core/extension/_resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def filter_func(item):
return filter_func


def resolve_from_index(extension_name, cur_version=None, index_url=None, target_version=None):
def resolve_from_index(extension_name, cur_version=None, index_url=None, target_version=None, cli_ctx=None):
"""
Gets the download Url and digest for the matching extension

:param cur_version: threshold verssion to filter out extensions.
"""
candidates = get_index_extensions(index_url=index_url).get(extension_name, [])
candidates = get_index_extensions(index_url=index_url, cli_ctx=cli_ctx).get(extension_name, [])

if not candidates:
raise NoExtensionCandidatesError("No extension found with name '{}'".format(extension_name))
Expand Down Expand Up @@ -90,6 +90,16 @@ def resolve_from_index(extension_name, cur_version=None, index_url=None, target_
download_url, digest = chosen.get('downloadUrl'), chosen.get('sha256Digest')
if not download_url:
raise NoExtensionCandidatesError("No download url found.")
azmirror_endpoint = cli_ctx.cloud.endpoints.azmirror_storage_account_resource_id if cli_ctx and \
cli_ctx.cloud.endpoints.has_endpoint_set('azmirror_storage_account_resource_id') else None
config_index_url = cli_ctx.config.get('extension', 'index_url', None) if cli_ctx else None
if azmirror_endpoint and not config_index_url:
# when extension index and wheels are mirrored in airgapped clouds from public cloud
# the content of the index.json is not updated, so we need to modify the wheel url got
# from the index.json here.
import posixpath
whl_name = download_url.split('/')[-1]
download_url = posixpath.join(azmirror_endpoint, 'extensions', whl_name)
return download_url, digest


Expand Down
12 changes: 6 additions & 6 deletions src/azure-cli-core/azure/cli/core/extension/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def add_extension(cmd=None, source=None, extension_name=None, index_url=None, ye
return
logger.warning("Overriding development version of '%s' with production version.", extension_name)
try:
source, ext_sha256 = resolve_from_index(extension_name, index_url=index_url, target_version=version)
source, ext_sha256 = resolve_from_index(extension_name, index_url=index_url, target_version=version, cli_ctx=cmd_cli_ctx)
except NoExtensionCandidatesError as err:
logger.debug(err)

Expand Down Expand Up @@ -373,7 +373,7 @@ def update_extension(cmd=None, extension_name=None, index_url=None, pip_extra_in
ext = get_extension(extension_name, ext_type=WheelExtension)
cur_version = ext.get_version()
try:
download_url, ext_sha256 = resolve_from_index(extension_name, cur_version=cur_version, index_url=index_url, target_version=version)
download_url, ext_sha256 = resolve_from_index(extension_name, cur_version=cur_version, index_url=index_url, target_version=version, cli_ctx=cmd_cli_ctx)
except NoExtensionCandidatesError as err:
logger.debug(err)
msg = "Extension {} with version {} not found.".format(extension_name, version) if version else "No updates available for '{}'. Use --debug for more information.".format(extension_name)
Expand Down Expand Up @@ -405,8 +405,8 @@ def update_extension(cmd=None, extension_name=None, index_url=None, pip_extra_in
raise CLIError(e)


def list_available_extensions(index_url=None, show_details=False):
index_data = get_index_extensions(index_url=index_url)
def list_available_extensions(index_url=None, show_details=False, cli_ctx=None):
index_data = get_index_extensions(index_url=index_url, cli_ctx=cli_ctx)
if show_details:
return index_data
installed_extensions = get_extensions(ext_type=WheelExtension)
Expand Down Expand Up @@ -436,8 +436,8 @@ def list_available_extensions(index_url=None, show_details=False):
return results


def list_versions(extension_name, index_url=None):
index_data = get_index_extensions(index_url=index_url)
def list_versions(extension_name, index_url=None, cli_ctx=None):
index_data = get_index_extensions(index_url=index_url, cli_ctx=cli_ctx)

try:
exts = index_data[extension_name]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,19 +332,19 @@ def test_update_extension_exception_in_update_and_rolled_back(self):

def test_list_available_extensions_default(self):
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', autospec=True) as c:
list_available_extensions()
c.assert_called_once_with(None)
list_available_extensions(cli_ctx=self.cmd.cli_ctx)
c.assert_called_once_with(None, self.cmd.cli_ctx)

def test_list_available_extensions_operations_index_url(self):
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', autospec=True) as c:
index_url = 'http://contoso.com'
list_available_extensions(index_url=index_url)
c.assert_called_once_with(index_url)
list_available_extensions(index_url=index_url, cli_ctx=self.cmd.cli_ctx)
c.assert_called_once_with(index_url, self.cmd.cli_ctx)

def test_list_available_extensions_show_details(self):
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', autospec=True) as c:
list_available_extensions(show_details=True)
c.assert_called_once_with(None)
list_available_extensions(show_details=True, cli_ctx=self.cmd.cli_ctx)
c.assert_called_once_with(None, self.cmd.cli_ctx)

def test_list_available_extensions_no_show_details(self):
sample_index_extensions = {
Expand All @@ -364,7 +364,7 @@ def test_list_available_extensions_no_show_details(self):
}}]
}
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', return_value=sample_index_extensions):
res = list_available_extensions()
res = list_available_extensions(cli_ctx=self.cmd.cli_ctx)
self.assertIsInstance(res, list)
self.assertEqual(len(res), len(sample_index_extensions))
self.assertEqual(res[0]['name'], 'test_sample_extension1')
Expand All @@ -373,7 +373,7 @@ def test_list_available_extensions_no_show_details(self):
self.assertEqual(res[0]['preview'], False)
self.assertEqual(res[0]['experimental'], False)
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', return_value=sample_index_extensions):
res = list_available_extensions()
res = list_available_extensions(cli_ctx=self.cmd.cli_ctx)
self.assertIsInstance(res, list)
self.assertEqual(len(res), len(sample_index_extensions))
self.assertEqual(res[1]['name'], 'test_sample_extension2')
Expand All @@ -393,7 +393,7 @@ def test_list_available_extensions_incompatible_cli_version(self):
}}]
}
with mock.patch('azure.cli.core.extension.operations.get_index_extensions', return_value=sample_index_extensions):
res = list_available_extensions()
res = list_available_extensions(cli_ctx=self.cmd.cli_ctx)
self.assertIsInstance(res, list)
self.assertEqual(len(res), 0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ def test_get_index_extensions(self):
self.assertEqual(get_index_extensions(), None)
logger_mock.assert_called_once_with(ERR_UNABLE_TO_GET_EXTENSIONS)

# pylint: disable=line-too-long
def test_get_index_cloud(self):

from azure.cli.core.mock import DummyCli
cli_ctx = DummyCli()

default_data = {'extensions': {}}
obj = object()
cloud_data = {'extensions': {'myext': obj}}
# cli_ctx not passed
with mock.patch('requests.get', side_effect=mock_index_get_generator(DEFAULT_INDEX_URL, default_data)):
self.assertEqual(get_index_extensions(), {})
# cli_ctx passed but endpoint not set
delattr(cli_ctx.cloud.endpoints, 'azmirror_storage_account_resource_id')
with mock.patch('requests.get', side_effect=mock_index_get_generator(DEFAULT_INDEX_URL, default_data)):
self.assertEqual(get_index_extensions(cli_ctx=cli_ctx), {})
# cli_ctx passed and the endpoint is set
cli_ctx.cloud.endpoints.azmirror_storage_account_resource_id = 'http://contoso.com'
with mock.patch('requests.get', side_effect=mock_index_get_generator('http://contoso.com/extensions/index.json', cloud_data)):
self.assertEqual(get_index_extensions(cli_ctx=cli_ctx).get('myext'), obj)


if __name__ == '__main__':
unittest.main()
7 changes: 6 additions & 1 deletion src/azure-cli-core/azure/cli/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,16 @@ def _get_extension_command_tree(self):
return None
EXT_CMD_TREE.load(os.path.join(cli_ctx.config.config_dir, 'extensionCommandTree.json'), VALID_SECOND)
if not EXT_CMD_TREE.data:
import posixpath
import requests
from azure.cli.core.util import should_disable_connection_verify
try:
azmirror_endpoint = cli_ctx.cloud.endpoints.azmirror_storage_account_resource_id if cli_ctx and \
cli_ctx.cloud.endpoints.has_endpoint_set('azmirror_storage_account_resource_id') else None
url = posixpath.join(azmirror_endpoint, 'extensions', 'extensionCommandTree.json') if \
azmirror_endpoint else 'https://aka.ms/azExtCmdTree'
response = requests.get(
'https://aka.ms/azExtCmdTree',
url,
verify=(not should_disable_connection_verify()),
timeout=10)
except Exception as ex: # pylint: disable=broad-except
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ def extension_name_completion_list(cmd, prefix, namespace, **kwargs): # pylint:

@Completer
def extension_name_from_index_completion_list(cmd, prefix, namespace, **kwargs): # pylint: disable=unused-argument
return get_index_extensions().keys()
return get_index_extensions(cli_ctx=cmd.cli_ctx).keys()
8 changes: 4 additions & 4 deletions src/azure-cli/azure/cli/command_modules/extension/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def update_extension_cmd(cmd, extension_name, index_url=None, pip_extra_index_ur
pip_extra_index_urls=pip_extra_index_urls, pip_proxy=pip_proxy)


def list_available_extensions_cmd(index_url=None, show_details=False):
return list_available_extensions(index_url=index_url, show_details=show_details)
def list_available_extensions_cmd(cmd, index_url=None, show_details=False):
return list_available_extensions(index_url=index_url, show_details=show_details, cli_ctx=cmd.cli_ctx)


def list_versions_cmd(extension_name, index_url=None):
return list_versions(extension_name, index_url=index_url)
def list_versions_cmd(cmd, extension_name, index_url=None):
return list_versions(extension_name, index_url=index_url, cli_ctx=cmd.cli_ctx)