Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
/src/azure-cli/azure/cli/command_modules/container/ @samkreter
/src/azure-cli/azure/cli/command_modules/consumption/ @sandeepnl
/src/azure-cli/azure/cli/command_modules/dls/ @lewu-msft
/src/azure-cli/azure/cli/command_modules/extension/ @zikalino
/src/azure-cli/azure/cli/command_modules/extension/ @fengzhou-msft @haroldrandom
/src/azure-cli/azure/cli/command_modules/keyvault/ @bim-msft @fengzhou-msft
/src/azure-cli/azure/cli/command_modules/monitor/ @MyronFanQiu
/src/azure-cli/azure/cli/command_modules/natgateway/ @khannarheams @MyronFanQiu @haroldrandom
Expand Down
25 changes: 17 additions & 8 deletions src/azure-cli-core/azure/cli/core/extension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint: disable=line-too-long

import os
import traceback
import json
import re
import sys
import pkginfo

from azure.cli.core._config import GLOBAL_CONFIG_DIR, ENV_VAR_PREFIX
Expand All @@ -18,10 +18,11 @@
az_config = CLIConfig(config_dir=GLOBAL_CONFIG_DIR, config_env_var_prefix=ENV_VAR_PREFIX)
_CUSTOM_EXT_DIR = az_config.get('extension', 'dir', None)
_DEV_EXTENSION_SOURCES = az_config.get('extension', 'dev_sources', None)
_CUSTOM_EXT_SYS_DIR = az_config.get('extension', 'sys_dir', None)
EXTENSIONS_DIR = os.path.expanduser(_CUSTOM_EXT_DIR) if _CUSTOM_EXT_DIR else os.path.join(GLOBAL_CONFIG_DIR,
'cliextensions')
DEV_EXTENSION_SOURCES = _DEV_EXTENSION_SOURCES.split(',') if _DEV_EXTENSION_SOURCES else []
EXTENSIONS_SYS_DIR = os.path.join(get_python_lib(), 'azure-cli-extensions') if sys.platform.startswith('linux') else ""
EXTENSIONS_SYS_DIR = os.path.expanduser(_CUSTOM_EXT_SYS_DIR) if _CUSTOM_EXT_SYS_DIR else os.path.join(get_python_lib(), 'azure-cli-extensions')
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to test if this also works on Windows and Mac.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On windows, users need to open the shell as administrator. I have added message when users encounter the permission error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run as admin only required when user set customized install dir right? if yes, then less concern since no change to existing behavoir

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, not affecting default extension installation.


EXTENSIONS_MOD_PREFIX = 'azext_'

Expand Down Expand Up @@ -134,10 +135,11 @@ def get_version(self):

def get_metadata(self):
from glob import glob
if not extension_exists(self.name):
return None
metadata = {}
ext_dir = self.path or get_extension_path(self.name)

if not ext_dir or not os.path.isdir(ext_dir):
return None
info_dirs = glob(os.path.join(ext_dir, self.name.replace('-', '_') + '-' + '*.dist-info'))

azext_metadata = WheelExtension.get_azext_metadata(ext_dir)
Expand Down Expand Up @@ -199,11 +201,10 @@ def get_version(self):
return self.metadata.get('version')

def get_metadata(self):

if not extension_exists(self.name):
return None
metadata = {}
ext_dir = self.path
if not ext_dir or not os.path.isdir(ext_dir):
return None
egg_info_dirs = [f for f in os.listdir(ext_dir) if f.endswith('.egg-info')]
azext_metadata = DevExtension.get_azext_metadata(ext_dir)
if azext_metadata:
Expand Down Expand Up @@ -284,8 +285,16 @@ def get_extension_modname(ext_name=None, ext_dir=None):


def get_extension_path(ext_name):
# This will return the path for a WHEEL extension if exists.
ext_sys_path = os.path.join(EXTENSIONS_SYS_DIR, ext_name)
ext_path = os.path.join(EXTENSIONS_DIR, ext_name)
return ext_path if os.path.isdir(ext_path) else (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add os.path.isdir(ext_path) check here, is it because previous code may case potential bug ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous code use get_extension_path both before install when this path does not exist and when loading extensions. I splitted it to use the below build_extension_path for building the path before install, and to use this get_extension_path only for loading extensions. So we can add the isdir check here now.

ext_sys_path if os.path.isdir(ext_sys_path) else None)


def build_extension_path(ext_name, system=None):
# This will simply form the path for a WHEEL extension.
return os.path.join(EXTENSIONS_DIR, ext_name)
return os.path.join(EXTENSIONS_SYS_DIR, ext_name) if system else os.path.join(EXTENSIONS_DIR, ext_name)


def get_extensions(ext_type=None):
Expand Down
41 changes: 24 additions & 17 deletions src/azure-cli-core/azure/cli/core/extension/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint: disable=line-too-long

from collections import OrderedDict
import sys
import os
Expand All @@ -17,7 +19,7 @@
from pkg_resources import parse_version

from azure.cli.core.util import CLIError, reload_module
from azure.cli.core.extension import (extension_exists, get_extension_path, get_extensions, get_extension_modname,
from azure.cli.core.extension import (extension_exists, build_extension_path, get_extensions, get_extension_modname,
get_extension, ext_compat_with_cli,
EXT_METADATA_ISPREVIEW, EXT_METADATA_ISEXPERIMENTAL,
WheelExtension, DevExtension, ExtensionNotInstalledException, WHEEL_INFO_RE)
Expand All @@ -37,13 +39,14 @@
OUT_KEY_METADATA = 'metadata'
OUT_KEY_PREVIEW = 'preview'
OUT_KEY_EXPERIMENTAL = 'experimental'
OUT_KEY_PATH = 'path'

IS_WINDOWS = sys.platform.lower() in ['windows', 'win32']
LIST_FILE_PATH = os.path.join(os.sep, 'etc', 'apt', 'sources.list.d', 'azure-cli.list')
LSB_RELEASE_FILE = os.path.join(os.sep, 'etc', 'lsb-release')


def _run_pip(pip_exec_args):
def _run_pip(pip_exec_args, extension_path):
cmd = [sys.executable, '-m', 'pip'] + pip_exec_args + ['-vv', '--disable-pip-version-check', '--no-cache-dir']
logger.debug('Running: %s', cmd)
try:
Expand All @@ -53,6 +56,8 @@ def _run_pip(pip_exec_args):
except CalledProcessError as e:
logger.debug(e.output)
logger.debug(e)
if "PermissionError: [WinError 5]" in e.output:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's default install path after the change? suppose no change

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. no change for the default extension path.

logger.warning("You do not have the permission to add extensions in the target directory: %s. You may need to rerun on a shell as administrator.", os.path.split(extension_path)[0])
returncode = e.returncode
return returncode

Expand All @@ -79,7 +84,7 @@ def _validate_whl_extension(ext_file):
check_version_compatibility(azext_metadata)


def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_proxy=None): # pylint: disable=too-many-statements
def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_proxy=None, system=None): # pylint: disable=too-many-statements
cmd.cli_ctx.get_progress_controller().add(message='Analyzing')
if not source.endswith('.whl'):
raise ValueError('Unknown extension type. Only Python wheels are supported.')
Expand Down Expand Up @@ -135,7 +140,7 @@ def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_pr
check_distro_consistency()
cmd.cli_ctx.get_progress_controller().add(message='Installing')
# Install with pip
extension_path = get_extension_path(extension_name)
extension_path = build_extension_path(extension_name, system)
pip_args = ['install', '--target', extension_path, ext_file]

if pip_proxy:
Expand All @@ -146,7 +151,7 @@ def _add_whl_ext(cmd, source, ext_sha256=None, pip_extra_index_urls=None, pip_pr

logger.debug('Executing pip with args: %s', pip_args)
with HomebrewPipPatch():
pip_status_code = _run_pip(pip_args)
pip_status_code = _run_pip(pip_args, extension_path)
if pip_status_code > 0:
logger.debug('Pip failed so deleting anything we might have installed at %s', extension_path)
shutil.rmtree(extension_path, ignore_errors=True)
Expand All @@ -168,12 +173,12 @@ def is_valid_sha256sum(a_file, expected_sum):
return expected_sum == computed_hash, computed_hash


def _augment_telemetry_with_ext_info(extension_name):
def _augment_telemetry_with_ext_info(extension_name, ext=None):
# The extension must be available before calling this otherwise we can't get the version from metadata
if not extension_name:
return
try:
ext = get_extension(extension_name)
ext = ext or get_extension(extension_name)
ext_version = ext.version
set_extension_management_detail(extension_name, ext_version)
except Exception: # nopa pylint: disable=broad-except
Expand All @@ -200,7 +205,7 @@ def check_version_compatibility(azext_metadata):


def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=None, # pylint: disable=unused-argument
pip_extra_index_urls=None, pip_proxy=None):
pip_extra_index_urls=None, pip_proxy=None, system=None):
ext_sha256 = None
if extension_name:
cmd.cli_ctx.get_progress_controller().add(message='Searching')
Expand All @@ -220,13 +225,14 @@ def add_extension(cmd, source=None, extension_name=None, index_url=None, yes=Non
logger.debug(err)
raise CLIError("No matching extensions for '{}'. Use --debug for more information.".format(extension_name))
extension_name = _add_whl_ext(cmd=cmd, source=source, ext_sha256=ext_sha256,
pip_extra_index_urls=pip_extra_index_urls, pip_proxy=pip_proxy)
_augment_telemetry_with_ext_info(extension_name)
pip_extra_index_urls=pip_extra_index_urls, pip_proxy=pip_proxy, system=system)
try:
if extension_name and get_extension(extension_name).experimental:
ext = get_extension(extension_name)
_augment_telemetry_with_ext_info(extension_name, ext)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move telemetry part into the try/except, so if get_extension throw exception, it will not be executed, right ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. get_extension was also called in _augment_telemetry_with_ext_info in previous code and it will be catched inside. And I think it's the same flow.

if extension_name and ext.experimental:
logger.warning("The installed extension '%s' is experimental and not covered by customer support. "
"Please use with discretion.", extension_name)
elif extension_name and get_extension(extension_name).preview:
elif extension_name and ext.preview:
logger.warning("The installed extension '%s' is in preview.", extension_name)
except ExtensionNotInstalledException:
pass
Expand All @@ -245,15 +251,15 @@ def log_err(func, path, exc_info):
"Extension '{name}' was installed in development mode. Remove using "
"`azdev extension remove {name}`".format(name=extension_name))
# We call this just before we remove the extension so we can get the metadata before it is gone
_augment_telemetry_with_ext_info(extension_name)
shutil.rmtree(get_extension_path(extension_name), onerror=log_err)
_augment_telemetry_with_ext_info(extension_name, ext)
shutil.rmtree(ext.path, onerror=log_err)
except ExtensionNotInstalledException as e:
raise CLIError(e)


def list_extensions():
return [{OUT_KEY_NAME: ext.name, OUT_KEY_VERSION: ext.version, OUT_KEY_TYPE: ext.ext_type,
OUT_KEY_PREVIEW: ext.preview, OUT_KEY_EXPERIMENTAL: ext.experimental}
OUT_KEY_PREVIEW: ext.preview, OUT_KEY_EXPERIMENTAL: ext.experimental, OUT_KEY_PATH: ext.path}
for ext in get_extensions()]


Expand All @@ -263,7 +269,8 @@ def show_extension(extension_name):
return {OUT_KEY_NAME: extension.name,
OUT_KEY_VERSION: extension.version,
OUT_KEY_TYPE: extension.ext_type,
OUT_KEY_METADATA: extension.metadata}
OUT_KEY_METADATA: extension.metadata,
OUT_KEY_PATH: extension.path}
except ExtensionNotInstalledException as e:
raise CLIError(e)

Expand All @@ -279,7 +286,7 @@ def update_extension(cmd, extension_name, index_url=None, pip_extra_index_urls=N
raise CLIError("No updates available for '{}'. Use --debug for more information.".format(extension_name))
# Copy current version of extension to tmp directory in case we need to restore it after a failed install.
backup_dir = os.path.join(tempfile.mkdtemp(), extension_name)
extension_path = get_extension_path(extension_name)
extension_path = ext.path
logger.debug('Backing up the current extension: %s to %s', extension_path, backup_dir)
shutil.copytree(extension_path, backup_dir)
# Remove current version of the extension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import mock

from azure.cli.core.util import CLIError
from azure.cli.core.extension import build_extension_path
from azure.cli.core.extension.operations import (list_extensions, add_extension, show_extension,
remove_extension, update_extension,
list_available_extensions, OUT_KEY_NAME, OUT_KEY_VERSION, OUT_KEY_METADATA)
list_available_extensions, OUT_KEY_NAME, OUT_KEY_VERSION,
OUT_KEY_METADATA, OUT_KEY_PATH)
from azure.cli.core.extension._resolve import NoExtensionCandidatesError
from azure.cli.core.mock import DummyCli

Expand All @@ -39,13 +41,18 @@ class TestExtensionCommands(unittest.TestCase):

def setUp(self):
self.ext_dir = tempfile.mkdtemp()
self.patcher = mock.patch('azure.cli.core.extension.EXTENSIONS_DIR', self.ext_dir)
self.patcher.start()
self.ext_sys_dir = tempfile.mkdtemp()
self.patchers = [mock.patch('azure.cli.core.extension.EXTENSIONS_DIR', self.ext_dir),
mock.patch('azure.cli.core.extension.EXTENSIONS_SYS_DIR', self.ext_sys_dir)]
for patcher in self.patchers:
patcher.start()
self.cmd = self._setup_cmd()

def tearDown(self):
self.patcher.stop()
for patcher in self.patchers:
patcher.stop()
shutil.rmtree(self.ext_dir, ignore_errors=True)
shutil.rmtree(self.ext_sys_dir, ignore_errors=True)

def test_no_extensions_dir(self):
shutil.rmtree(self.ext_dir)
Expand All @@ -66,6 +73,34 @@ def test_add_list_show_remove_extension(self):
num_exts = len(list_extensions())
self.assertEqual(num_exts, 0)

def test_add_list_show_remove_system_extension(self):
add_extension(cmd=self.cmd, source=MY_EXT_SOURCE, system=True)
actual = list_extensions()
self.assertEqual(len(actual), 1)
ext = show_extension(MY_EXT_NAME)
self.assertEqual(ext[OUT_KEY_NAME], MY_EXT_NAME)
remove_extension(MY_EXT_NAME)
num_exts = len(list_extensions())
self.assertEqual(num_exts, 0)

def test_add_list_show_remove_user_system_extensions(self):
add_extension(cmd=self.cmd, source=MY_EXT_SOURCE)
add_extension(cmd=self.cmd, source=MY_SECOND_EXT_SOURCE_DASHES, system=True)
actual = list_extensions()
self.assertEqual(len(actual), 2)
ext = show_extension(MY_EXT_NAME)
self.assertEqual(ext[OUT_KEY_NAME], MY_EXT_NAME)
self.assertEqual(ext[OUT_KEY_PATH], build_extension_path(MY_EXT_NAME))
second_ext = show_extension(MY_SECOND_EXT_NAME_DASHES)
self.assertEqual(second_ext[OUT_KEY_NAME], MY_SECOND_EXT_NAME_DASHES)
self.assertEqual(second_ext[OUT_KEY_PATH], build_extension_path(MY_SECOND_EXT_NAME_DASHES, system=True))
remove_extension(MY_EXT_NAME)
num_exts = len(list_extensions())
self.assertEqual(num_exts, 1)
remove_extension(MY_SECOND_EXT_NAME_DASHES)
num_exts = len(list_extensions())
self.assertEqual(num_exts, 0)

def test_add_list_show_remove_extension_with_dashes(self):
add_extension(cmd=self.cmd, source=MY_SECOND_EXT_SOURCE_DASHES)
actual = list_extensions()
Expand All @@ -85,6 +120,13 @@ def test_add_extension_twice(self):
with self.assertRaises(CLIError):
add_extension(cmd=self.cmd, source=MY_EXT_SOURCE)

def test_add_same_extension_user_system(self):
add_extension(cmd=self.cmd, source=MY_EXT_SOURCE)
num_exts = len(list_extensions())
self.assertEqual(num_exts, 1)
with self.assertRaises(CLIError):
add_extension(cmd=self.cmd, source=MY_EXT_SOURCE, system=True)

def test_add_extension_invalid(self):
with self.assertRaises(ValueError):
add_extension(cmd=self.cmd, source=MY_BAD_EXT_SOURCE)
Expand Down
Binary file not shown.
Loading