diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a60a827..1c8ae56c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - CI: update type/`mypy` check ([#288](https://github.com/Lightning-AI/utilities/pull/288)) +- Fixed parsing pre-release package versions in `RequirementCache` ([#292](https://github.com/Lightning-AI/utilities/pull/292)) ## [0.11.4] - 2024-07-15 diff --git a/src/lightning_utilities/__about__.py b/src/lightning_utilities/__about__.py index f49ae695..9e1a0c86 100644 --- a/src/lightning_utilities/__about__.py +++ b/src/lightning_utilities/__about__.py @@ -1,6 +1,6 @@ import time -__version__ = "0.11.5" +__version__ = "0.11.6" __author__ = "Lightning AI et al." __author_email__ = "pytorch@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/lightning_utilities/core/imports.py b/src/lightning_utilities/core/imports.py index d8435966..a4f0fc1e 100644 --- a/src/lightning_utilities/core/imports.py +++ b/src/lightning_utilities/core/imports.py @@ -128,7 +128,7 @@ def _check_requirement(self) -> None: try: req = Requirement(self.requirement) pkg_version = Version(_version(req.name)) - self.available = req.specifier.contains(pkg_version) and ( + self.available = req.specifier.contains(pkg_version, prereleases=True) and ( not req.extras or self._check_extras_available(req) ) except (PackageNotFoundError, InvalidVersion) as ex: @@ -180,7 +180,7 @@ def _check_extras_available(self, requirement: Requirement) -> bool: try: extra_dist = distribution(extra_req.name) extra_installed_version = Version(extra_dist.version) - if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version): + if extra_req.specifier and not extra_req.specifier.contains(extra_installed_version, prereleases=True): return False except importlib.metadata.PackageNotFoundError: return False diff --git a/tests/unittests/core/test_imports.py b/tests/unittests/core/test_imports.py index 2c406feb..4c8c9f00 100644 --- a/tests/unittests/core/test_imports.py +++ b/tests/unittests/core/test_imports.py @@ -99,6 +99,16 @@ def test_requirement_cache_with_extras(distribution_mock, version_mock, requirem assert not RequirementCache("jsonargparse[signatures]>=1.0.0") +@mock.patch("lightning_utilities.core.imports._version") +def test_requirement_cache_with_prerelease_package(version_mock): + version_mock.return_value = "0.11.0" + assert RequirementCache("transformer-engine>=0.11.0") + version_mock.return_value = "0.11.0.dev0+931b44f" + assert not RequirementCache("transformer-engine>=0.11.0") + version_mock.return_value = "1.10.0.dev0+931b44f" + assert RequirementCache("transformer-engine>=0.11.0") + + def test_module_available_cache(): assert RequirementCache(module="pytest") assert not RequirementCache(module="this_module_is_not_installed")