Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bump: support torch>=2.0 #2671

Merged
merged 76 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
76 commits
Select commit Hold shift + click to select a range
4ea6de2
ci
Borda Aug 2, 2024
8ac5de3
req
Borda Aug 2, 2024
9b15f98
torch
Borda Aug 2, 2024
5ef5c67
torchvision
Borda Aug 2, 2024
d566521
torchaudio
Borda Aug 2, 2024
649bdd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
38db864
fix
Borda Aug 2, 2024
f068e7f
cudnn9
Borda Aug 2, 2024
25774ea
Merge branch 'master' into bump/torch
Borda Aug 3, 2024
401a78c
Merge branch 'master' into bump/torch
Borda Aug 5, 2024
a4c0e31
images
Borda Aug 5, 2024
a74ad43
_TORCHVISION_AVAILABLE
Borda Aug 5, 2024
529f7cf
Merge branch 'master' into bump/torch
Borda Aug 5, 2024
5ff8c78
ver
Borda Aug 5, 2024
0b9ec14
Merge branch 'bump/torch' of https://github.com/PyTorchLightning/metr…
Borda Aug 5, 2024
0510bc4
ver
Borda Aug 5, 2024
7963e25
ver
Borda Aug 5, 2024
9e4a79f
tqdm
Borda Aug 5, 2024
fdb6195
xfail
Borda Aug 5, 2024
1da7f51
bump pystoi
Borda Aug 5, 2024
7d967f4
Merge branch 'master' into bump/torch
Borda Aug 5, 2024
ed34404
Merge branch 'master' into bump/torch
Borda Aug 7, 2024
f3587f1
Merge branch 'master' into bump/torch
Borda Aug 8, 2024
765ea32
Merge branch 'master' into bump/torch
Borda Aug 14, 2024
84a7c9d
Merge branch 'master' into bump/torch
Borda Aug 27, 2024
51d7e7b
pyGithub
Borda Aug 27, 2024
3e1c93c
Update requirements/_tests.txt
SkafteNicki Aug 28, 2024
b9efe07
Merge branch 'master' into bump/torch
SkafteNicki Aug 28, 2024
ccee20c
changelog
SkafteNicki Aug 28, 2024
7700c41
skip kendal on mac
SkafteNicki Aug 28, 2024
c860b49
nltk
Borda Aug 28, 2024
969be58
pygithub
Borda Aug 28, 2024
51804cf
tqdm
Borda Aug 28, 2024
ecd32ac
Empty-Commit
Borda Aug 28, 2024
6c7037a
doctest
Borda Aug 28, 2024
83d642c
librosa
Borda Aug 29, 2024
6fee0b9
Merge branch 'master' into bump/torch
Borda Aug 30, 2024
c9f5957
Merge branch 'master' into bump/torch
Borda Aug 30, 2024
097a480
stoi
Borda Aug 30, 2024
089d644
atol = 1e-7
Borda Aug 30, 2024
75c2a9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 30, 2024
66008e3
scikit-learn
Borda Aug 30, 2024
a4487ff
Merge branch 'bump/torch' of https://github.com/PyTorchLightning/metr…
Borda Aug 30, 2024
6c6ece8
> =
Borda Aug 30, 2024
51567b2
hamming
Borda Aug 30, 2024
665d50f
atol=1e-3
Borda Aug 30, 2024
7c01448
vals
Borda Aug 30, 2024
3e4ca7d
atol=5e-3
Borda Aug 31, 2024
b80a9ec
rtol=1e-2
Borda Sep 1, 2024
de43778
Merge branch 'master' into bump/torch
Borda Sep 2, 2024
8e03554
Merge branch 'master' into bump/torch
Borda Sep 2, 2024
cb95add
Merge branch 'master' into bump/torch
Borda Sep 2, 2024
043ceef
Merge branch 'master' into bump/torch
Borda Sep 3, 2024
d6787fb
more...
Borda Sep 3, 2024
39d1e24
skip
Borda Sep 3, 2024
397ecd1
smaller array to fix torch.unique case
SkafteNicki Sep 3, 2024
71cb200
95
Borda Sep 3, 2024
3a298bf
dython ~=0.7.7
Borda Sep 3, 2024
69f282f
dython ==0.7.*
Borda Sep 3, 2024
7c80f5f
pandas >1.4.0
Borda Sep 3, 2024
47b45c6
matplotlib >=3.5.0
Borda Sep 3, 2024
5661434
Apply suggestions from code review
Borda Sep 3, 2024
70fa7ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2024
fcbfbb5
Apply suggestions from code review
Borda Sep 3, 2024
c034159
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2024
cd56411
Apply suggestions from code review
Borda Sep 3, 2024
7387290
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 3, 2024
442d514
fix indention
SkafteNicki Sep 4, 2024
a0b3488
more fixes
SkafteNicki Sep 4, 2024
7862a67
more fixes
SkafteNicki Sep 4, 2024
2119299
more fixes
SkafteNicki Sep 4, 2024
33637ba
dython ==0.7.7
Borda Sep 4, 2024
dafd744
matplotlib >=3.6.0
Borda Sep 4, 2024
cd868ca
Merge branch 'master' into bump/torch
SkafteNicki Sep 7, 2024
7528c39
Merge branch 'master' into bump/torch
Borda Sep 9, 2024
6ccb5bc
Merge branch 'master' into bump/torch
Borda Oct 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .azure/gpu-integrations.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ jobs:
- job: integrate_GPU
strategy:
matrix:
"torch | 1.x":
docker-image: "pytorchlightning/torchmetrics:ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
torch-ver: "1.13"
"torch | 2.0":
docker-image: "pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime"
torch-ver: "2.0"
requires: "oldest"
"torch | 2.x":
docker-image: "pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime"
Expand Down
15 changes: 6 additions & 9 deletions .azure/gpu-unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@ jobs:
- job: unitest_GPU
strategy:
matrix:
"PyTorch | 1.10 oldest":
"PyTorch | 2.0 oldest":
# Torch does not have build wheels with old Torch versions for newer CUDA
docker-image: "ubuntu20.04-cuda11.3.1-py3.9-torch1.10"
torch-ver: "1.10"
"PyTorch | 1.X LTS":
docker-image: "ubuntu22.04-cuda11.8.0-py3.9-torch1.13"
torch-ver: "1.13"
docker-image: "ubuntu22.04-cuda11.8.0-py3.10-torch2.0"
torch-ver: "2.0"
"PyTorch | 2.X stable":
docker-image: "ubuntu22.04-cuda12.1.1-py3.11-torch2.4"
torch-ver: "2.4"
Expand Down Expand Up @@ -123,7 +120,7 @@ jobs:

- bash: |
python .github/assistant.py set-oldest-versions
condition: eq(variables['torch-ver'], '1.10')
condition: eq(variables['torch-ver'], '2.0')
displayName: "Setting oldest versions"

- bash: |
Expand Down Expand Up @@ -191,7 +188,7 @@ jobs:
workingDirectory: "tests/"
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
timeoutInMinutes: "90"
timeoutInMinutes: "95"
displayName: "UnitTesting common"

- bash: |
Expand All @@ -203,7 +200,7 @@ jobs:
workingDirectory: "tests/"
# skip for PR if there is nothing to test, note that outside PR there is default 'unittests'
condition: and(succeeded(), ne(variables['TEST_DIRS'], ''))
timeoutInMinutes: "90"
timeoutInMinutes: "95"
displayName: "UnitTesting DDP"

- bash: |
Expand Down
7 changes: 0 additions & 7 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,22 @@ jobs:
os: ["ubuntu-20.04"]
python-version: ["3.9"]
pytorch-version:
- "1.10.2"
- "1.11.0"
- "1.12.1"
- "1.13.1"
- "2.0.1"
- "2.1.2"
- "2.2.2"
- "2.3.1"
- "2.4.0"
include:
# cover additional python and PT combinations
- { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" }
# standard mac machine, not the M1
- { os: "macOS-13", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "macOS-13", python-version: "3.10", pytorch-version: "2.0.1" }
# using the ARM based M1 machine
- { os: "macOS-14", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "macOS-14", python-version: "3.11", pytorch-version: "2.4.0" }
# some windows
- { os: "windows-2022", python-version: "3.8", pytorch-version: "1.13.1" }
- { os: "windows-2022", python-version: "3.10", pytorch-version: "2.0.1" }
- { os: "windows-2022", python-version: "3.11", pytorch-version: "2.4.0" }
# Future released version
Expand Down
3 changes: 0 additions & 3 deletions .github/workflows/docker-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ jobs:
include:
# These are the base images for PL release docker images,
# so include at least all the combinations in release-dockers.yml.
- { python: "3.9", pytorch: "1.10", cuda: "11.3.1", ubuntu: "20.04" }
#- { python: "3.9", pytorch: "1.11", cuda: "11.8.0", ubuntu: "22.04" }
- { python: "3.9", pytorch: "1.13", cuda: "11.8.0", ubuntu: "22.04" }
- { python: "3.10", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
- { python: "3.11", pytorch: "2.2", cuda: "12.1.1", ubuntu: "22.04" }
- { python: "3.11", pytorch: "2.3", cuda: "12.1.1", ubuntu: "22.04" }
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))


### Fixed
Expand Down
1 change: 1 addition & 0 deletions requirements/_tests.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

codecov ==2.1.13
coverage ==7.6.*
codecov ==2.1.13
pytest ==8.3.*
Expand Down
4 changes: 2 additions & 2 deletions requirements/audio.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# this need to be the same as used inside speechmetrics
pesq >=0.0.4, <0.0.5
pystoi >=0.4.0, <0.5.0
torchaudio >=0.10.0, <2.5.0
torchaudio >=2.0.1, <2.5.0
gammatone >=1.0.0, <1.1.0
librosa >=0.9.0, <0.11.0
librosa >=0.10.0, <0.11.0
onnxruntime >=1.12.0, <1.20 # installing onnxruntime_gpu-gpu failed on macos
requests >=2.19.0, <2.33.0
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

numpy >1.20.0, <2.0 # strict, for compatibility reasons
packaging >17.1
torch >=1.10.0, <2.5.0
torch >=2.0.0, <2.5.0
typing-extensions; python_version < '3.9'
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/detection.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

torchvision >=0.8, <0.20.0
torchvision >=0.15.1, <0.20.0
pycocotools >2.0.0, <2.1.0
2 changes: 1 addition & 1 deletion requirements/image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

scipy >1.0.0, <1.15.0
torchvision >=0.8, <0.20.0
torchvision >=0.15.1, <0.20.0
torch-fidelity <=0.4.0 # bumping to allow install version from master, now used in testing
11 changes: 9 additions & 2 deletions src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)

if _SCIPI_AVAILABLE:
import scipy.signal

# back compatibility patch due to SMRMpy using scipy.signal.hamming
if not hasattr(scipy.signal, "hamming"):
scipy.signal.hamming = scipy.signal.windows.hamming

__all__ = [
"PermutationInvariantTraining",
"ScaleInvariantSignalDistortionRatio",
Expand All @@ -52,7 +59,7 @@

__all__ += ["ShortTimeObjectiveIntelligibility"]

if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE:
from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio

__all__ += ["SpeechReverberationModulationEnergyRatio"]
Expand Down
5 changes: 2 additions & 3 deletions src/torchmetrics/audio/srmr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
_GAMMATONE_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_TORCHAUDIO_GREATER_EQUAL_0_10,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10]):
if not all([_GAMMATONE_AVAILABLE, _TORCHAUDIO_AVAILABLE]):
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio", "SpeechReverberationModulationEnergyRatio.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SpeechReverberationModulationEnergyRatio.plot"]
Expand Down Expand Up @@ -105,7 +104,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not _TORCHAUDIO_AVAILABLE or not _TORCHAUDIO_GREATER_EQUAL_0_10 or not _GAMMATONE_AVAILABLE:
if not _TORCHAUDIO_AVAILABLE or not _GAMMATONE_AVAILABLE:
raise ModuleNotFoundError(
"speech_reverberation_modulation_energy_ratio requires you to have `gammatone` and"
" `torchaudio>=0.10` installed. Either install as ``pip install torchmetrics[audio]`` or "
Expand Down
23 changes: 11 additions & 12 deletions src/torchmetrics/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.detection.panoptic_qualities import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import (
_TORCHVISION_GREATER_EQUAL_0_8,
_TORCHVISION_GREATER_EQUAL_0_13,
)
from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE

__all__ = ["ModifiedPanopticQuality", "PanopticQuality"]

if _TORCHVISION_GREATER_EQUAL_0_8:
if _TORCHVISION_AVAILABLE:
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
from torchmetrics.detection.diou import DistanceIntersectionOverUnion
from torchmetrics.detection.giou import GeneralizedIntersectionOverUnion
from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.detection.mean_ap import MeanAveragePrecision

__all__ += ["MeanAveragePrecision", "GeneralizedIntersectionOverUnion", "IntersectionOverUnion"]

if _TORCHVISION_GREATER_EQUAL_0_13:
from torchmetrics.detection.ciou import CompleteIntersectionOverUnion
from torchmetrics.detection.diou import DistanceIntersectionOverUnion

__all__ += ["CompleteIntersectionOverUnion", "DistanceIntersectionOverUnion"]
__all__ += [
"MeanAveragePrecision",
"GeneralizedIntersectionOverUnion",
"IntersectionOverUnion",
"CompleteIntersectionOverUnion",
"DistanceIntersectionOverUnion",
]
9 changes: 0 additions & 9 deletions src/torchmetrics/detection/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
from typing import Any, Collection

from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from torchmetrics.utilities.prints import _deprecated_root_import_class

if not _TORCH_GREATER_EQUAL_1_12:
__doctest_skip__ = [
"_PanopticQuality",
"_PanopticQuality.*",
"_ModifiedPanopticQuality",
"_ModifiedPanopticQuality.*",
]


class _ModifiedPanopticQuality(ModifiedPanopticQuality):
"""Wrapper for deprecated import.
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/_mean_ap.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import _cumsum
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYCOCOTOOLS_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MeanAveragePrecision.plot"]

if not _TORCHVISION_GREATER_EQUAL_0_8 or not _PYCOCOTOOLS_AVAILABLE:
if not _TORCHVISION_AVAILABLE or not _PYCOCOTOOLS_AVAILABLE:
__doctest_skip__ = ["MeanAveragePrecision.plot", "MeanAveragePrecision"]

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -327,10 +327,10 @@ def __init__(
"`MAP` metric requires that `pycocotools` installed."
" Please install with `pip install pycocotools` or `pip install torchmetrics[detection]`"
)
Borda marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
"`MeanAveragePrecision` metric requires that `torchvision` version 0.8.0 or newer is installed."
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
"`MeanAveragePrecision` metric requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)

allowed_box_formats = ("xyxy", "xywh", "cxcywh")
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/ciou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.ciou import _ciou_compute, _ciou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["CompleteIntersectionOverUnion", "CompleteIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["CompleteIntersectionOverUnion.plot"]
Expand Down Expand Up @@ -110,10 +110,10 @@ def __init__(
respect_labels: bool = True,
**kwargs: Any,
) -> None:
Borda marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/diou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.diou import _diou_compute, _diou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_13
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["DistanceIntersectionOverUnion", "DistanceIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["DistanceIntersectionOverUnion.plot"]
Expand Down Expand Up @@ -110,10 +110,10 @@ def __init__(
respect_labels: bool = True,
**kwargs: Any,
) -> None:
Borda marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCHVISION_GREATER_EQUAL_0_13:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.13.0 or newer is installed."
" Please install with `pip install torchvision>=0.13` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)
super().__init__(box_format, iou_threshold, class_metrics, respect_labels, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/detection/giou.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

from torchmetrics.detection.iou import IntersectionOverUnion
from torchmetrics.functional.detection.giou import _giou_compute, _giou_update
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["GeneralizedIntersectionOverUnion", "GeneralizedIntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["GeneralizedIntersectionOverUnion.plot"]
Expand Down
10 changes: 5 additions & 5 deletions src/torchmetrics/detection/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from torchmetrics.functional.detection.iou import _iou_compute, _iou_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_GREATER_EQUAL_0_8
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
__doctest_skip__ = ["IntersectionOverUnion", "IntersectionOverUnion.plot"]
elif not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["IntersectionOverUnion.plot"]
Expand Down Expand Up @@ -146,10 +146,10 @@ def __init__(
) -> None:
super().__init__(**kwargs)

Borda marked this conversation as resolved.
Show resolved Hide resolved
if not _TORCHVISION_GREATER_EQUAL_0_8:
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError(
f"Metric `{self._iou_type.upper()}` requires that `torchvision` version 0.8.0 or newer is installed."
" Please install with `pip install torchvision>=0.8` or `pip install torchmetrics[detection]`."
f"Metric `{self._iou_type.upper()}` requires that `torchvision` is installed."
" Please install with `pip install torchmetrics[detection]`."
)

allowed_box_formats = ("xyxy", "xywh", "cxcywh")
Expand Down
Loading
Loading