Skip to content

Commit

Permalink
Fix the dependencies of LPIPS metric (#2230)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>

(cherry picked from commit 1b16341)
  • Loading branch information
tanguymagne authored and Borda committed Dec 1, 2023
1 parent c665da1 commit 69a2d44
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Removed `lpips` third-party package as dependency of `LearnedPerceptualImagePatchSimilarity` metric ([#2230](https://github.com/Lightning-AI/torchmetrics/pull/2230))


### Fixed
Expand Down
1 change: 0 additions & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@
scipy >1.0.0, <1.11.0
torchvision >=0.8, <0.17.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
lpips <=0.1.4
1 change: 1 addition & 0 deletions requirements/image_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pytorch-msssim ==1.0.0
sewar >=0.4.4, <=0.4.6
numpy <1.25.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
lpips <=0.1.4
6 changes: 3 additions & 3 deletions src/torchmetrics/image/lpip.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the
chosen backbone (see `net_type` arg).
.. note:: using this metrics requires you to have ``lpips`` package installed. Either install
as ``pip install torchmetrics[image]`` or ``pip install lpips``
.. note:: using this metrics requires you to have ``torchvision`` package installed. Either install as
``pip install torchmetrics[image]`` or ``pip install torchvision``.
.. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation
if this is a issue.
Expand All @@ -71,7 +71,7 @@ class LearnedPerceptualImagePatchSimilarity(Metric):
Raises:
ModuleNotFoundError:
If ``lpips`` package is not installed
If ``torchvision`` package is not installed
ValueError:
If ``net_type`` is not one of ``"vgg"``, ``"alex"`` or ``"squeeze"``
ValueError:
Expand Down
5 changes: 4 additions & 1 deletion tests/unittests/image/test_lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torch import Tensor
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_GREATER_EQUAL_1_9
from torchmetrics.utilities.imports import _LPIPS_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TORCHVISION_AVAILABLE

from unittests.helpers import seed_all
from unittests.helpers.testers import MetricTester
Expand Down Expand Up @@ -48,6 +48,7 @@ def _compare_fn(img1: Tensor, img2: Tensor, net_type: str, normalize: bool = Fal
return res.sum()


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
class TestLPIPS(MetricTester):
"""Test class for `LearnedPerceptualImagePatchSimilarity` metric."""
Expand Down Expand Up @@ -107,6 +108,7 @@ def test_normalize_arg(normalize):
assert res == res2


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
def test_error_on_wrong_init():
"""Test class raises the expected errors."""
Expand All @@ -117,6 +119,7 @@ def test_error_on_wrong_init():
LearnedPerceptualImagePatchSimilarity(net_type="squeeze", reduction=None)


@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="test requires that torchvision is installed")
@pytest.mark.skipif(not _LPIPS_AVAILABLE, reason="test requires that lpips is installed")
@pytest.mark.parametrize(
("inp1", "inp2"),
Expand Down

0 comments on commit 69a2d44

Please sign in to comment.