Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions homeassistant/util/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def install_package(package: str, upgrade: bool = True,
"""
# Not using 'import pip; pip.main([])' because it breaks the logger
with INSTALL_LOCK:
if check_package_exists(package):
if package_loadable(package):
return True

_LOGGER.info('Attempting install of %s', package)
Expand Down Expand Up @@ -61,8 +61,8 @@ def install_package(package: str, upgrade: bool = True,
return True


def check_package_exists(package: str) -> bool:
"""Check if a package is installed globally or in lib_dir.
def package_loadable(package: str) -> bool:
"""Check if a package is what will be loaded when we import it.

Returns True when the requirement is met.
Returns False when the package is not installed or doesn't meet req.
Expand All @@ -73,8 +73,14 @@ def check_package_exists(package: str) -> bool:
# This is a zip file
req = pkg_resources.Requirement.parse(urlparse(package).fragment)

env = pkg_resources.Environment()
return any(dist in req for dist in env[req.project_name])
for path in sys.path:
for dist in pkg_resources.find_distributions(path):
# If the project name is the same, it will be the one that is
# loaded when we import it.
if dist.project_name == req.project_name:
return dist in req

return False


async def async_get_user_site(deps_dir: str) -> str:
Expand Down
30 changes: 26 additions & 4 deletions tests/util/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def mock_sys():

@pytest.fixture
def mock_exists():
"""Mock check_package_exists."""
with patch('homeassistant.util.package.check_package_exists') as mock:
"""Mock package_loadable."""
with patch('homeassistant.util.package.package_loadable') as mock:
mock.return_value = False
yield mock

Expand Down Expand Up @@ -193,12 +193,12 @@ def test_install_constraint(
def test_check_package_global():
"""Test for an installed package."""
installed_package = list(pkg_resources.working_set)[0].project_name
assert package.check_package_exists(installed_package)
assert package.package_loadable(installed_package)


def test_check_package_zip():
"""Test for an installed zip package."""
assert not package.check_package_exists(TEST_ZIP_REQ)
assert not package.package_loadable(TEST_ZIP_REQ)


@asyncio.coroutine
Expand All @@ -217,3 +217,25 @@ def test_async_get_user_site(mock_env_copy):
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.DEVNULL,
env=env)
assert ret == os.path.join(deps_dir, 'lib_dir')


def test_package_loadable_installed_twice():
"""Test that a package is loadable when installed twice.

If a package is installed twice, only the first version will be imported.
Test that package_loadable will only compare with the first package.
"""
v1 = pkg_resources.Distribution(project_name='hello', version='1.0.0')
v2 = pkg_resources.Distribution(project_name='hello', version='2.0.0')

with patch('pkg_resources.find_distributions', side_effect=[[v1]]):
assert not package.package_loadable('hello==2.0.0')

with patch('pkg_resources.find_distributions', side_effect=[[v1], [v2]]):
assert not package.package_loadable('hello==2.0.0')

with patch('pkg_resources.find_distributions', side_effect=[[v2], [v1]]):
assert package.package_loadable('hello==2.0.0')

with patch('pkg_resources.find_distributions', side_effect=[[v2]]):
assert package.package_loadable('hello==2.0.0')