From 96ceda0e759d57cfce8346e434c352f63f2683e1 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:30:49 +0200 Subject: [PATCH] fix: compatibility audio do with new `scipy` (#2733) * compatibility audio do with new `scipy` * smaller array to fix torch.unique case --------- Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 3 +++ src/torchmetrics/audio/__init__.py | 8 ++++++++ src/torchmetrics/functional/audio/__init__.py | 8 ++++++++ src/torchmetrics/functional/nominal/__init__.py | 9 +++++++++ src/torchmetrics/nominal/__init__.py | 9 +++++++++ src/torchmetrics/utilities/imports.py | 1 + tests/unittests/classification/test_stat_scores.py | 4 ++-- 7 files changed, 40 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bac68f736f..0fc6c936492 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Correct the padding related calculation errors in SSIM ([#2721](https://github.com/Lightning-AI/torchmetrics/pull/2721)) +- Fixed compatibility of audio domain with new `scipy` ([#2733](https://github.com/Lightning-AI/torchmetrics/pull/2733)) + + - Fixed how `prefix`/`postfix` works in `MultitaskWrapper` ([#2722](https://github.com/Lightning-AI/torchmetrics/pull/2722)) diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 6d21902b13e..05ad2dd2729 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -28,10 +28,18 @@ _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, + _SCIPI_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10, ) +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + __all__ = [ "PermutationInvariantTraining", "ScaleInvariantSignalDistortionRatio", diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py index c8a8b5a4bcc..0d4fb1e9a88 100644 --- a/src/torchmetrics/functional/audio/__init__.py +++ b/src/torchmetrics/functional/audio/__init__.py @@ -28,10 +28,18 @@ _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, + _SCIPI_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHAUDIO_GREATER_EQUAL_0_10, ) +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming + __all__ = [ "permutation_invariant_training", "pit_permutate", diff --git a/src/torchmetrics/functional/nominal/__init__.py b/src/torchmetrics/functional/nominal/__init__.py index f29dd9302f0..54ef36a8a90 100644 --- a/src/torchmetrics/functional/nominal/__init__.py +++ b/src/torchmetrics/functional/nominal/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.functional.nominal.cramers import cramers_v, cramers_v_matrix from torchmetrics.functional.nominal.fleiss_kappa import fleiss_kappa from torchmetrics.functional.nominal.pearson import ( @@ -19,6 +20,14 @@ ) from torchmetrics.functional.nominal.theils_u import theils_u, theils_u_matrix from torchmetrics.functional.nominal.tschuprows import tschuprows_t, tschuprows_t_matrix +from torchmetrics.utilities.imports import _SCIPI_AVAILABLE + +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming __all__ = [ "cramers_v", diff --git a/src/torchmetrics/nominal/__init__.py b/src/torchmetrics/nominal/__init__.py index f23a7eb8c6b..b4b3625e30e 100644 --- a/src/torchmetrics/nominal/__init__.py +++ b/src/torchmetrics/nominal/__init__.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from torchmetrics.nominal.cramers import CramersV from torchmetrics.nominal.fleiss_kappa import FleissKappa from torchmetrics.nominal.pearson import PearsonsContingencyCoefficient from torchmetrics.nominal.theils_u import TheilsU from torchmetrics.nominal.tschuprows import TschuprowsT +from torchmetrics.utilities.imports import _SCIPI_AVAILABLE + +if _SCIPI_AVAILABLE: + import scipy.signal + + # back compatibility patch due to SMRMpy using scipy.signal.hamming + if not hasattr(scipy.signal, "hamming"): + scipy.signal.hamming = scipy.signal.windows.hamming __all__ = [ "CramersV", diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index b40a334558f..10affebf579 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -64,6 +64,7 @@ _MECAB_KO_DIC_AVAILABLE = RequirementCache("mecab_ko_dic") _IPADIC_AVAILABLE = RequirementCache("ipadic") _SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece") +_SCIPI_AVAILABLE = RequirementCache("scipy") _SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 53fa78d0368..5ea4c206bc0 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -582,8 +582,8 @@ def test_support_for_int(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970.""" seed_all(42) metric = MulticlassStatScores(num_classes=4, average="none", multidim_average="samplewise", ignore_index=0) - prediction = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) - label = torch.randint(low=0, high=4, size=(1, 224, 224)).to(torch.uint8) + prediction = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) + label = torch.randint(low=0, high=4, size=(1, 50, 50)).to(torch.uint8) score = metric(preds=prediction, target=label) assert score.shape == (1, 4, 5)