From e08e1c170d25946c339cccabfaab1bf2065bae59 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 15:24:11 -0500 Subject: [PATCH 01/45] Add plot functionality to PerceptualEvaluationSpeechQuality --- src/torchmetrics/audio/pesq.py | 51 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 7 ++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 4a18e770e11..defd80c4a6c 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Union, Sequence from torch import Tensor, tensor from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _PESQ_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"PerceptualEvaluationSpeechQuality": ["pesq"]} @@ -119,3 +120,51 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Computes metric.""" return self.sum_pesq / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality() + >>> metric.update(torch.rand(10), torch.randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> metric = PerceptualEvaluationSpeechQuality() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(10), torch.randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6c09b082fd2..8ca1cb46b12 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -20,6 +20,7 @@ import torch from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy +from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, @@ -49,6 +50,12 @@ lambda: torch.randint(3, (100,)), id="multiclass and average=None", ), + pytest.param( + partial(perceptual_evaluation_speech_quality), + lambda: torch.randint(3, (100,)), + lambda: torch.randint(3, (100,)), + id="perceptual_evaluation_speech_quality", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 5d301e8573572019a28c215400d258fdc4f2e221 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Jan 2023 20:25:12 +0000 Subject: [PATCH 02/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/pesq.py | 2 +- tests/unittests/utilities/test_plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index defd80c4a6c..96e38fdfccf 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union, Sequence +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 8ca1cb46b12..e43763743d7 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,8 @@ import pytest import torch -from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality +from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, From 3ff8372c5312cdf00ed2014bd3844d88a0ded8c7 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 15:35:19 -0500 Subject: [PATCH 03/45] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa2202d04c9..3a146be32a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328)) +- Added support for plotting of audio metrics through `.plot()` method ([#1434](https://github.com/Lightning-AI/metrics/pull/1434)) ### Changed From c785a94112bb0f266585f01b1b78ffb7e0e357a3 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 15:53:42 -0500 Subject: [PATCH 04/45] Remove test for PerceptualEvaluationSpeechQuality plot --- tests/unittests/utilities/test_plot.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index e43763743d7..51e2539776d 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,7 +19,6 @@ import pytest import torch -from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -49,13 +48,7 @@ lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), id="multiclass and average=None", - ), - pytest.param( - partial(perceptual_evaluation_speech_quality), - lambda: torch.randint(3, (100,)), - lambda: torch.randint(3, (100,)), - id="perceptual_evaluation_speech_quality", - ), + ) ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From c63e87ccc0887ea78f34c339cb5b89027905692e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Jan 2023 20:54:21 +0000 Subject: [PATCH 05/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 51e2539776d..6c09b082fd2 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -48,7 +48,7 @@ lambda: torch.randint(3, (100,)), lambda: torch.randint(3, (100,)), id="multiclass and average=None", - ) + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 8c79abaa3af0a8049f0f5a078a87d2c8bd9666be Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 15:57:58 -0500 Subject: [PATCH 06/45] Skip doctest if matplotlib is not available --- src/torchmetrics/audio/pesq.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 96e38fdfccf..b840c4fddea 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -18,10 +18,14 @@ from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _PESQ_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"PerceptualEvaluationSpeechQuality": ["pesq"]} +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PerceptualEvaluationSpeechQuality.plot"] + class PerceptualEvaluationSpeechQuality(Metric): """Calculates `Perceptual Evaluation of Speech Quality`_ (PESQ). It's a recognized industry standard for audio From 4b51d1344b1dbb282973c8e8128cf722d87a5a21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Jan 2023 20:58:51 +0000 Subject: [PATCH 07/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/pesq.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index b840c4fddea..ede503e6f76 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -17,8 +17,7 @@ from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _PESQ_AVAILABLE -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PESQ_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"PerceptualEvaluationSpeechQuality": ["pesq"]} From 428c4ba1982a9b0205d37c7ca4aa7a58c89561de Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 16:38:35 -0500 Subject: [PATCH 08/45] Update pesq plot docs --- src/torchmetrics/audio/pesq.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index ede503e6f76..dfd5e45d72e 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -84,6 +84,7 @@ class PerceptualEvaluationSpeechQuality(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0} def __init__( self, @@ -151,7 +152,7 @@ def plot( >>> import torch >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality >>> metric = PerceptualEvaluationSpeechQuality() - >>> metric.update(torch.rand(10), torch.randint(2,(10,))) + >>> metric.update(torch.rand(10), torch.rand(10)) >>> fig_, ax_ = metric.plot() .. plot:: @@ -163,7 +164,7 @@ def plot( >>> metric = PerceptualEvaluationSpeechQuality() >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.rand(10), torch.randint(2,(10,)))) + ... values.append(metric(torch.rand(10), torch.rand(10))) >>> fig_, ax_ = metric.plot(values) """ val = val or self.compute() From bd7337807319f15dc4acbf9c8114d4c63e0f3d4d Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 16:48:22 -0500 Subject: [PATCH 09/45] Add test for pesq plotting --- tests/unittests/utilities/test_plot.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6c09b082fd2..6ffc6038c13 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -20,6 +20,7 @@ import torch from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy +from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, @@ -49,6 +50,12 @@ lambda: torch.randint(3, (100,)), id="multiclass and average=None", ), + pytest.param( + partial(perceptual_evaluation_speech_quality, fs=8000, mode='nb'), + lambda: torch.randn(8000), + lambda: torch.randn(8000), + id="perceptual_evaluation_speech_quality", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 88b2f2b4cfaf53cfe762b1f6c8c8b69d79fe2c96 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Jan 2023 21:49:00 +0000 Subject: [PATCH 10/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6ffc6038c13..effb305c1ac 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,8 @@ import pytest import torch -from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality +from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, multiclass_confusion_matrix, @@ -51,7 +51,7 @@ id="multiclass and average=None", ), pytest.param( - partial(perceptual_evaluation_speech_quality, fs=8000, mode='nb'), + partial(perceptual_evaluation_speech_quality, fs=8000, mode="nb"), lambda: torch.randn(8000), lambda: torch.randn(8000), id="perceptual_evaluation_speech_quality", From 006797ace6d3c7059ad2b5de0f32be647b1edf69 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 17:01:27 -0500 Subject: [PATCH 11/45] Update docs --- src/torchmetrics/audio/pesq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index dfd5e45d72e..bbd68d5ac3a 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -151,8 +151,8 @@ def plot( >>> # Example plotting a single value >>> import torch >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality - >>> metric = PerceptualEvaluationSpeechQuality() - >>> metric.update(torch.rand(10), torch.rand(10)) + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + >>> metric.update(torch.rand(8000), torch.rand(8000)) >>> fig_, ax_ = metric.plot() .. plot:: From 38a86c1437f540b33698843df3de38097f3f40b6 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 17:14:22 -0500 Subject: [PATCH 12/45] Update docs --- src/torchmetrics/audio/pesq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index bbd68d5ac3a..72b6cc9421e 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -161,7 +161,7 @@ def plot( >>> # Example plotting multiple values >>> import torch >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality - >>> metric = PerceptualEvaluationSpeechQuality() + >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> values = [ ] >>> for _ in range(10): ... values.append(metric(torch.rand(10), torch.rand(10))) From f4eca1552272547ba0add43825a3a48d758793d9 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 19:07:19 -0500 Subject: [PATCH 13/45] Add pesq plotting example --- examples/plotting.py | 20 ++++++++++++++++++++ src/torchmetrics/audio/pesq.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/plotting.py b/examples/plotting.py index f877e57c721..29c4460a7c3 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -17,6 +17,25 @@ import torch +def pesq_example(): + from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + + p = lambda: torch.randn(8000) + t = lambda: torch.randn(8000) + + # plot single value + metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = PerceptualEvaluationSpeechQuality(16000, 'wb') + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + def accuracy_example(): from torchmetrics.classification import MulticlassAccuracy @@ -82,6 +101,7 @@ def confusion_matrix_example(): metrics_func = { "accuracy": accuracy_example, + "pesq": pesq_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, } diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 72b6cc9421e..9735f0761e9 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -84,7 +84,7 @@ class PerceptualEvaluationSpeechQuality(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True - plot_options: dict = {"lower_bound": 0.0, "upper_bound": 1.0} + plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5} def __init__( self, From 35b9adbf7db888dfa1c93a18ce59f0e0a8443a87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Jan 2023 00:07:57 +0000 Subject: [PATCH 14/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 29c4460a7c3..fcfb28da8eb 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -24,12 +24,12 @@ def pesq_example(): t = lambda: torch.randn(8000) # plot single value - metric = PerceptualEvaluationSpeechQuality(8000, 'nb') + metric = PerceptualEvaluationSpeechQuality(8000, "nb") metric.update(p(), t()) fig, ax = metric.plot() # plot multiple values - metric = PerceptualEvaluationSpeechQuality(16000, 'wb') + metric = PerceptualEvaluationSpeechQuality(16000, "wb") vals = [metric(p(), t()) for _ in range(10)] fig, ax = metric.plot(vals) From 2effd880323ced4683069bafd8fdb253cd703483 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 8 Jan 2023 19:15:28 -0500 Subject: [PATCH 15/45] Add pesq to init file --- src/torchmetrics/audio/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 152228ac36a..4ed2680c0b3 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -21,3 +21,7 @@ if _PYSTOI_AVAILABLE: from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility # noqa: F401 + +__all__ = [ + "PerceptualEvaluationSpeechQuality" +] \ No newline at end of file From e8e505d9bd22ca8e9ccd0ffef613a3f603fd849a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Jan 2023 00:16:12 +0000 Subject: [PATCH 16/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 4ed2680c0b3..4f8d52e491e 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -17,11 +17,9 @@ from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE if _PESQ_AVAILABLE: - from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401 + from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality if _PYSTOI_AVAILABLE: from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility # noqa: F401 -__all__ = [ - "PerceptualEvaluationSpeechQuality" -] \ No newline at end of file +__all__ = ["PerceptualEvaluationSpeechQuality"] From ffd4d7dce09d26a6b5f1c7433ca2b4f4faa8f2d1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 13:35:20 +0100 Subject: [PATCH 17/45] fix --- CHANGELOG.md | 2 ++ src/torchmetrics/audio/__init__.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 646704c9666..f88c9c97cd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added support for plotting of metrics through `.plot()` method ([#1328](https://github.com/Lightning-AI/metrics/pull/1328)) + + - Added support for plotting of audio metrics through `.plot()` method ([#1434](https://github.com/Lightning-AI/metrics/pull/1434)) diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 4f8d52e491e..152228ac36a 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -17,9 +17,7 @@ from torchmetrics.utilities.imports import _PESQ_AVAILABLE, _PYSTOI_AVAILABLE if _PESQ_AVAILABLE: - from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401 if _PYSTOI_AVAILABLE: from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility # noqa: F401 - -__all__ = ["PerceptualEvaluationSpeechQuality"] From d15501ec161935f0d6fc03d32469c256c107ed8f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 13:39:48 +0100 Subject: [PATCH 18/45] add requirement --- requirements/docs.txt | 1 + src/torchmetrics/audio/pesq.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements/docs.txt b/requirements/docs.txt index 24e5c11cd05..b034c7008f9 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -14,3 +14,4 @@ sphinx-copybutton>=0.3 # integrations -r integrate.txt -r visual.txt +-r audio.txt diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 63fe5d5963e..dac1c5e8817 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -150,7 +150,7 @@ def plot( >>> # Example plotting a single value >>> import torch - >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> metric.update(torch.rand(8000), torch.rand(8000)) >>> fig_, ax_ = metric.plot() @@ -160,7 +160,7 @@ def plot( >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> values = [ ] >>> for _ in range(10): From 9463a6d5bd378b49e2e4e379572c1f454c3f7476 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 13:56:33 +0100 Subject: [PATCH 19/45] fix doctest --- src/torchmetrics/audio/pesq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index dac1c5e8817..b782cec0c9e 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -164,7 +164,7 @@ def plot( >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb') >>> values = [ ] >>> for _ in range(10): - ... values.append(metric(torch.rand(10), torch.rand(10))) + ... values.append(metric(torch.rand(8000), torch.rand(8000))) >>> fig_, ax_ = metric.plot(values) """ val = val or self.compute() From 3bec0238c02aa75cfde9c159829af3f22c7e04b0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:24:34 +0900 Subject: [PATCH 20/45] gh: update templates (#1477)Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen * gh: update templates --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bug_report.md | 5 ++++- .github/ISSUE_TEMPLATE/documentation.md | 5 ++--- .github/PULL_REQUEST_TEMPLATE.md | 12 +++++++++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 487b6e8bdc2..150f7849963 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -16,11 +16,14 @@ Steps to reproduce the behavior... -#### Code sample +
+ Code sample +
+ ### Expected behavior diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md index c74b3408000..456d49be454 100644 --- a/.github/ISSUE_TEMPLATE/documentation.md +++ b/.github/ISSUE_TEMPLATE/documentation.md @@ -10,8 +10,7 @@ assignees: '' For typos and doc fixes, please go ahead and: -1. Create an issue. -1. Fix the typo. -1. Submit a PR. +- For a simple typo or fix, please send directly a PR (no need to create an issue) +- If you are not sure about the proper solution, please describe here your finding... Thanks! diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0c6881d0228..318ee5483eb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,18 +2,24 @@ Fixes #\ -## Before submitting +
+ Before submitting -- [ ] Was this **discussed/approved** via a Github issue? (no need for typos and docs improvements) +- [ ] Was this **discussed/agreed** via a Github issue? (no need for typos and docs improvements) - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/metrics/blob/master/.github/CONTRIBUTING.md), Pull Request section? - [ ] Did you make sure to **update the docs**? - [ ] Did you write any new **necessary tests**? -## PR review +
+ +
+ PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. +
+ ## Did you have fun? Make sure you had fun coding 🙃 From 40af3a237fb85ca980cf71b716fab41fea55ae49 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Fri, 3 Feb 2023 13:46:15 -0500 Subject: [PATCH 21/45] Add plot function for PermutationInvariantTraining --- examples/plotting.py | 20 ++++++++++ src/torchmetrics/audio/pit.py | 53 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 8 ++++ 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/examples/plotting.py b/examples/plotting.py index fcfb28da8eb..03e6a1ced8d 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -35,6 +35,25 @@ def pesq_example(): return fig, ax +def pit_example(): + from torchmetrics.audio.pit import PermutationInvariantTraining + from torchmetrics.functional import scale_invariant_signal_noise_ratio + + p = lambda: torch.randn(3, 2, 5) + t = lambda: torch.randn(3, 2, 5) + + # plot single value + metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + def accuracy_example(): from torchmetrics.classification import MulticlassAccuracy @@ -102,6 +121,7 @@ def confusion_matrix_example(): metrics_func = { "accuracy": accuracy_example, "pesq": pesq_example, + "pit": pit_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, } diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index e68f74797e3..ba71aab2ecf 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -11,14 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Optional, Sequence, Union from torch import Tensor, tensor from typing_extensions import Literal from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["PermutationInvariantTraining.plot"] class PermutationInvariantTraining(Metric): """Calculates `Permutation invariant training`_ (PIT) that can evaluate models for speaker independent multi- @@ -60,6 +64,7 @@ class PermutationInvariantTraining(Metric): is_differentiable: bool = True sum_pit_metric: Tensor total: Tensor + plot_options: dict = {"lower_bound": -10.0, "upper_bound": 1.0} def __init__( self, @@ -90,3 +95,49 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Computes metric.""" return self.sum_pit_metric / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + Returns: + fig: Figure object + ax: Axes object + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + Examples: + .. plot:: + :scale: 75 + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.pit import PermutationInvariantTraining + >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + >>> metric = pit(preds, target) + >>> fig_, ax_ = metric.plot() + .. plot:: + :scale: 75 + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio + >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] + >>> target = torch.randn(3, 2, 5) # [batch, spk, time] + >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + >>> values = [ ] + >>> for _ in range(10): + ... values.append(pit(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index effb305c1ac..6fb3612aab3 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -20,6 +20,8 @@ import torch from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality +from torchmetrics.functional.audio.pit import permutation_invariant_training +from torchmetrics.functional import scale_invariant_signal_noise_ratio from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -56,6 +58,12 @@ lambda: torch.randn(8000), id="perceptual_evaluation_speech_quality", ), + pytest.param( + partial(permutation_invariant_training, scale_invariant_signal_noise_ratio, 'max'), + lambda: torch.randn(3, 2, 5), + lambda: torch.randn(3, 2, 5), + id="permutation_invariant_training", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From 01ddadc6bae9d240f7fca4e7dd543e346d38fa5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Feb 2023 18:47:01 +0000 Subject: [PATCH 22/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 5 +++-- src/torchmetrics/audio/pit.py | 3 ++- tests/unittests/utilities/test_plot.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 03e6a1ced8d..c20fa55a99b 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -35,6 +35,7 @@ def pesq_example(): return fig, ax + def pit_example(): from torchmetrics.audio.pit import PermutationInvariantTraining from torchmetrics.functional import scale_invariant_signal_noise_ratio @@ -43,12 +44,12 @@ def pit_example(): t = lambda: torch.randn(3, 2, 5) # plot single value - metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, "max") metric.update(p(), t()) fig, ax = metric.plot() # plot multiple values - metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, "max") vals = [metric(p(), t()) for _ in range(10)] fig, ax = metric.plot(vals) diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index ba71aab2ecf..a262fff49c6 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -24,6 +24,7 @@ if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["PermutationInvariantTraining.plot"] + class PermutationInvariantTraining(Metric): """Calculates `Permutation invariant training`_ (PIT) that can evaluate models for speaker independent multi- talker speech separation in a permutation invariant way. @@ -97,7 +98,7 @@ def compute(self) -> Tensor: return self.sum_pit_metric / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 6fb3612aab3..ea6ae9259d4 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,9 +19,9 @@ import pytest import torch +from torchmetrics.functional import scale_invariant_signal_noise_ratio from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training -from torchmetrics.functional import scale_invariant_signal_noise_ratio from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -59,7 +59,7 @@ id="perceptual_evaluation_speech_quality", ), pytest.param( - partial(permutation_invariant_training, scale_invariant_signal_noise_ratio, 'max'), + partial(permutation_invariant_training, scale_invariant_signal_noise_ratio, "max"), lambda: torch.randn(3, 2, 5), lambda: torch.randn(3, 2, 5), id="permutation_invariant_training", From 6b81ad704aafb20621b00b6ae165aacdcb7ff116 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 14:34:55 -0500 Subject: [PATCH 23/45] Update docstring for plot in PermutationInvariantTraining --- src/torchmetrics/audio/pit.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index a262fff49c6..38aad646f98 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -101,19 +101,25 @@ def plot( self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. + Args: val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. If no value is provided, will automatically call `metric.compute` and plot that result. ax: An matplotlib axis object. If provided will add plot to that axis + Returns: fig: Figure object ax: Axes object + Raises: ModuleNotFoundError: If `matplotlib` is not installed + Examples: + .. plot:: :scale: 75 + >>> # Example plotting a single value >>> import torch >>> from torchmetrics.audio.pit import PermutationInvariantTraining @@ -123,11 +129,13 @@ def plot( >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') >>> metric = pit(preds, target) >>> fig_, ax_ = metric.plot() + .. plot:: :scale: 75 + >>> # Example plotting multiple values >>> import torch - >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality + >>> from torchmetrics.audio.pit import PermutationInvariantTraining >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] From 32ef72b8dced922a296e892e012f5c7e8bd4a664 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 14:45:46 -0500 Subject: [PATCH 24/45] Add doctest requirement for pit --- src/torchmetrics/audio/pit.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 38aad646f98..6a8622804de 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -21,6 +21,8 @@ from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val +__doctest_requires__ = {"PermutationInvariantTraining": ["pit"]} + if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["PermutationInvariantTraining.plot"] From 77664e85c5534d1b98057505850875bffbc7e8b8 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 14:58:13 -0500 Subject: [PATCH 25/45] Fix docstring of pit plot --- src/torchmetrics/audio/pit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 6a8622804de..796870ad400 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -128,8 +128,8 @@ def plot( >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') - >>> metric = pit(preds, target) + >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + >>> metric.update(preds, target) >>> fig_, ax_ = metric.plot() .. plot:: @@ -141,10 +141,10 @@ def plot( >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio >>> preds = torch.randn(3, 2, 5) # [batch, spk, time] >>> target = torch.randn(3, 2, 5) # [batch, spk, time] - >>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') + >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max') >>> values = [ ] >>> for _ in range(10): - ... values.append(pit(preds, target)) + ... values.append(metric(preds, target)) >>> fig_, ax_ = metric.plot(values) """ val = val or self.compute() From c83d6290ac8d78b4ad77546cce5f3d13d1414b47 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 17:15:54 -0500 Subject: [PATCH 26/45] Fix plot test for pit --- tests/unittests/utilities/test_plot.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index ea6ae9259d4..5bc639faba3 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -57,20 +57,35 @@ lambda: torch.randn(8000), lambda: torch.randn(8000), id="perceptual_evaluation_speech_quality", - ), + ) + ], +) +@pytest.mark.parametrize("num_vals", [1, 5, 10]) +def test_single_multi_val_plotter(metric, preds, target, num_vals): + vals = [] + for i in range(num_vals): + vals.append(metric(preds(), target())) + vals = vals[0] if i == 1 else vals + fig, ax = plot_single_or_multi_val(vals) + assert isinstance(fig, plt.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + +@pytest.mark.parametrize( + "metric, preds, target", + [ pytest.param( - partial(permutation_invariant_training, scale_invariant_signal_noise_ratio, "max"), + partial(permutation_invariant_training, metric_func=scale_invariant_signal_noise_ratio, eval_func="max"), lambda: torch.randn(3, 2, 5), lambda: torch.randn(3, 2, 5), id="permutation_invariant_training", - ), + ) ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) -def test_single_multi_val_plotter(metric, preds, target, num_vals): +def test_single_multi_val_plotter_pit(metric, preds, target, num_vals): vals = [] for i in range(num_vals): - vals.append(metric(preds(), target())) + vals.append(metric(preds(), target())[0]) vals = vals[0] if i == 1 else vals fig, ax = plot_single_or_multi_val(vals) assert isinstance(fig, plt.Figure) From 52af69c1a853cf36127b49a7b68a9b6a0c7cd10e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Feb 2023 22:16:22 +0000 Subject: [PATCH 27/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 5bc639faba3..d6b9e296730 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -57,7 +57,7 @@ lambda: torch.randn(8000), lambda: torch.randn(8000), id="perceptual_evaluation_speech_quality", - ) + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) @@ -70,6 +70,7 @@ def test_single_multi_val_plotter(metric, preds, target, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + @pytest.mark.parametrize( "metric, preds, target", [ From f96aca8ed397b6a2441b6961db0360b3b40dc6b2 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 23:44:54 -0500 Subject: [PATCH 28/45] Add plot functionality for sdr audio metrics --- src/torchmetrics/audio/sdr.py | 111 ++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 16 +++- 2 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index efddeb42efd..024bb125d0e 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -11,15 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"SignalDistortionRatio": ["fast_bss_eval"]} +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SignalDistortionRatio.plot", "ScaleInvariantSignalDistortionRatio.plot"] + class SignalDistortionRatio(Metric): r"""Calculates Signal to Distortion Ratio (SDR) metric. See `SDR ref1`_ and `SDR ref2`_ for details on the @@ -79,6 +84,7 @@ class SignalDistortionRatio(Metric): full_state_update: bool = False is_differentiable: bool = True higher_is_better: bool = True + plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5} def __init__( self, @@ -111,6 +117,55 @@ def compute(self) -> Tensor: """Computes metric.""" return self.sum_sdr / self.total + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.sdr import SignalDistortionRatio + >>> metric = SignalDistortionRatio() + >>> metric.update(torch.rand(8000), torch.rand(8000)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.sdr import SignalDistortionRatio + >>> metric = SignalDistortionRatio() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(8000), torch.rand(8000))) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax + + class ScaleInvariantSignalDistortionRatio(Metric): """`Scale-invariant signal-to-distortion ratio`_ (SI-SDR). The SI-SDR value is in general considered an overall @@ -147,6 +202,7 @@ class ScaleInvariantSignalDistortionRatio(Metric): higher_is_better = True sum_si_sdr: Tensor total: Tensor + plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5} def __init__( self, @@ -169,3 +225,56 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Computes metric.""" return self.sum_si_sdr / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio + >>> target = torch.randn(5) + >>> preds = torch.randn(5) + >>> metric = ScaleInvariantSignalDistortionRatio() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio + >>> target = torch.randn(5) + >>> preds = torch.randn(5) + >>> metric = ScaleInvariantSignalDistortionRatio() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax + diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index d6b9e296730..92a2d99181f 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,7 +19,8 @@ import pytest import torch -from torchmetrics.functional import scale_invariant_signal_noise_ratio +from torchmetrics.functional import scale_invariant_signal_noise_ratio, signal_distortion_ratio, \ + scale_invariant_signal_distortion_ratio from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy @@ -58,6 +59,18 @@ lambda: torch.randn(8000), id="perceptual_evaluation_speech_quality", ), + pytest.param( + partial(signal_distortion_ratio), + lambda: torch.randn(8000), + lambda: torch.randn(8000), + id="signal_distortion_ratio", + ), + pytest.param( + partial(scale_invariant_signal_distortion_ratio), + lambda: torch.randn(5), + lambda: torch.randn(5), + id="scale_invariant_signal_distortion_ratio", + ) ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) @@ -70,7 +83,6 @@ def test_single_multi_val_plotter(metric, preds, target, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) - @pytest.mark.parametrize( "metric, preds, target", [ From 0540c50544f52b560bda672a328e1db05c3659b3 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sat, 4 Feb 2023 23:56:12 -0500 Subject: [PATCH 29/45] Add plot examples for sdr audio metrics --- examples/plotting.py | 38 +++++++++++++++++++++++++++++++++++ src/torchmetrics/audio/sdr.py | 4 ++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index c20fa55a99b..a6b104120dd 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -55,6 +55,42 @@ def pit_example(): return fig, ax +def sdr_example(): + from torchmetrics.audio.sdr import SignalDistortionRatio + + p = lambda: torch.randn(8000) + t = lambda: torch.randn(8000) + + # plot single value + metric = SignalDistortionRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SignalDistortionRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + +def si_sdr_example(): + from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio + + p = lambda: torch.randn(5) + t = lambda: torch.randn(5) + + # plot single value + metric = ScaleInvariantSignalDistortionRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = ScaleInvariantSignalDistortionRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + def accuracy_example(): from torchmetrics.classification import MulticlassAccuracy @@ -123,6 +159,8 @@ def confusion_matrix_example(): "accuracy": accuracy_example, "pesq": pesq_example, "pit": pit_example, + "sdr": sdr_example, + "si-sdr": si_sdr_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, } diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 024bb125d0e..4636bf3dcae 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -84,7 +84,7 @@ class SignalDistortionRatio(Metric): full_state_update: bool = False is_differentiable: bool = True higher_is_better: bool = True - plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5} + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 1.0} def __init__( self, @@ -202,7 +202,7 @@ class ScaleInvariantSignalDistortionRatio(Metric): higher_is_better = True sum_si_sdr: Tensor total: Tensor - plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5} + plot_options: dict = {"lower_bound": -40.0, "upper_bound": 20.0} def __init__( self, From 7df844d73a0e5fd0b4f5d429209053190e911666 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Feb 2023 05:00:01 +0000 Subject: [PATCH 30/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/plotting.py | 2 ++ src/torchmetrics/audio/sdr.py | 2 -- tests/unittests/utilities/test_plot.py | 10 +++++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index a6b104120dd..31ffebf2b13 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -55,6 +55,7 @@ def pit_example(): return fig, ax + def sdr_example(): from torchmetrics.audio.sdr import SignalDistortionRatio @@ -73,6 +74,7 @@ def sdr_example(): return fig, ax + def si_sdr_example(): from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 4636bf3dcae..7eeb561fd19 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -166,7 +166,6 @@ def plot( return fig, ax - class ScaleInvariantSignalDistortionRatio(Metric): """`Scale-invariant signal-to-distortion ratio`_ (SI-SDR). The SI-SDR value is in general considered an overall measure of how good a source sound. @@ -277,4 +276,3 @@ def plot( val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ ) return fig, ax - diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 92a2d99181f..ea939313a8c 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,11 @@ import pytest import torch -from torchmetrics.functional import scale_invariant_signal_noise_ratio, signal_distortion_ratio, \ - scale_invariant_signal_distortion_ratio +from torchmetrics.functional import ( + scale_invariant_signal_distortion_ratio, + scale_invariant_signal_noise_ratio, + signal_distortion_ratio, +) from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy @@ -70,7 +73,7 @@ lambda: torch.randn(5), lambda: torch.randn(5), id="scale_invariant_signal_distortion_ratio", - ) + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) @@ -83,6 +86,7 @@ def test_single_multi_val_plotter(metric, preds, target, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + @pytest.mark.parametrize( "metric, preds, target", [ From 6b252f9dc17412753dd7159fb372344b73754088 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 5 Feb 2023 15:15:04 -0500 Subject: [PATCH 31/45] Add plot functionality for snr audio metrics --- src/torchmetrics/audio/snr.py | 104 ++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 20 +++-- 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index c7f6db3e7ee..c11e77bfcd7 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -11,13 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"] class SignalNoiseRatio(Metric): r"""Calculates `Signal-to-noise ratio`_ (SNR_) meric for evaluating quality of audio. It is defined as: @@ -59,6 +63,7 @@ class SignalNoiseRatio(Metric): higher_is_better: bool = True sum_snr: Tensor total: Tensor + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 1.0} def __init__( self, @@ -82,6 +87,54 @@ def compute(self) -> Tensor: """Computes metric.""" return self.sum_snr / self.total + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.snr import SignalNoiseRatio + >>> metric = SignalNoiseRatio() + >>> metric.update(torch.rand(4), torch.rand(4)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.snr import SignalNoiseRatio + >>> metric = SignalNoiseRatio() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(4), torch.rand(4))) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax + class ScaleInvariantSignalNoiseRatio(Metric): """Calculates `Scale-invariant signal-to-noise ratio`_ (SI-SNR) metric for evaluating quality of audio. @@ -116,6 +169,7 @@ class ScaleInvariantSignalNoiseRatio(Metric): sum_si_snr: Tensor total: Tensor higher_is_better = True + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 1.0} def __init__( self, @@ -136,3 +190,51 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Computes metric.""" return self.sum_si_snr / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio + >>> metric = ScaleInvariantSignalNoiseRatio() + >>> metric.update(torch.rand(4), torch.rand(4)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio + >>> metric = ScaleInvariantSignalNoiseRatio() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(4), torch.rand(4))) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index ea939313a8c..09b76073492 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,11 +19,8 @@ import pytest import torch -from torchmetrics.functional import ( - scale_invariant_signal_distortion_ratio, - scale_invariant_signal_noise_ratio, - signal_distortion_ratio, -) +from torchmetrics.functional import scale_invariant_signal_noise_ratio, signal_distortion_ratio, \ + scale_invariant_signal_distortion_ratio, signal_noise_ratio from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy @@ -74,6 +71,18 @@ lambda: torch.randn(5), id="scale_invariant_signal_distortion_ratio", ), + pytest.param( + partial(signal_noise_ratio), + lambda: torch.randn(4), + lambda: torch.randn(4), + id="signal_noise_ratio", + ), + pytest.param( + partial(scale_invariant_signal_noise_ratio), + lambda: torch.randn(4), + lambda: torch.randn(4), + id="scale_invariant_signal_noise_ratio", + ) ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) @@ -86,7 +95,6 @@ def test_single_multi_val_plotter(metric, preds, target, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) - @pytest.mark.parametrize( "metric, preds, target", [ From 3f7abc0006d8ada9c6076dd14dae17b8eb500d88 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Feb 2023 20:17:19 +0000 Subject: [PATCH 32/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/snr.py | 1 + tests/unittests/utilities/test_plot.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index c11e77bfcd7..57d88420585 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -23,6 +23,7 @@ if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"] + class SignalNoiseRatio(Metric): r"""Calculates `Signal-to-noise ratio`_ (SNR_) meric for evaluating quality of audio. It is defined as: diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 09b76073492..28f0d3b3478 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -19,8 +19,12 @@ import pytest import torch -from torchmetrics.functional import scale_invariant_signal_noise_ratio, signal_distortion_ratio, \ - scale_invariant_signal_distortion_ratio, signal_noise_ratio +from torchmetrics.functional import ( + scale_invariant_signal_distortion_ratio, + scale_invariant_signal_noise_ratio, + signal_distortion_ratio, + signal_noise_ratio, +) from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy @@ -82,7 +86,7 @@ lambda: torch.randn(4), lambda: torch.randn(4), id="scale_invariant_signal_noise_ratio", - ) + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) @@ -95,6 +99,7 @@ def test_single_multi_val_plotter(metric, preds, target, num_vals): assert isinstance(fig, plt.Figure) assert isinstance(ax, matplotlib.axes.Axes) + @pytest.mark.parametrize( "metric, preds, target", [ From ec04abcd237dce1abc4f5475c6f07920b0b0bc8b Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 5 Feb 2023 15:21:55 -0500 Subject: [PATCH 33/45] Add plotting example for snr audio metrics --- examples/plotting.py | 40 +++++++++++++++++++++++++++++++++++ src/torchmetrics/audio/snr.py | 4 ++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/examples/plotting.py b/examples/plotting.py index 31ffebf2b13..bc6e6bca241 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -94,6 +94,44 @@ def si_sdr_example(): return fig, ax +def snr_example(): + from torchmetrics.audio.snr import SignalNoiseRatio + + p = lambda: torch.randn(4) + t = lambda: torch.randn(4) + + # plot single value + metric = SignalNoiseRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = SignalNoiseRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + +def si_snr_example(): + from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio + + p = lambda: torch.randn(4) + t = lambda: torch.randn(4) + + # plot single value + metric = ScaleInvariantSignalNoiseRatio() + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = ScaleInvariantSignalNoiseRatio() + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + def accuracy_example(): from torchmetrics.classification import MulticlassAccuracy @@ -163,6 +201,8 @@ def confusion_matrix_example(): "pit": pit_example, "sdr": sdr_example, "si-sdr": si_sdr_example, + "snr": snr_example, + "si-snr": si_snr_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, } diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index 57d88420585..6d86365efd1 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -64,7 +64,7 @@ class SignalNoiseRatio(Metric): higher_is_better: bool = True sum_snr: Tensor total: Tensor - plot_options: dict = {"lower_bound": -20.0, "upper_bound": 1.0} + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 5.0} def __init__( self, @@ -170,7 +170,7 @@ class ScaleInvariantSignalNoiseRatio(Metric): sum_si_snr: Tensor total: Tensor higher_is_better = True - plot_options: dict = {"lower_bound": -20.0, "upper_bound": 1.0} + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 10.0} def __init__( self, From aeeb40d49b38f70c4fbc0aeb28816856509d3178 Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 5 Feb 2023 15:31:54 -0500 Subject: [PATCH 34/45] Add plot functionality for stoi audio metric --- src/torchmetrics/audio/stoi.py | 62 +++++++++++++++++++++++++- tests/unittests/utilities/test_plot.py | 7 +++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 2446e21c322..0107130b3b8 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -11,16 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional, Sequence, Union from torch import Tensor, tensor from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility from torchmetrics.metric import Metric from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"ShortTimeObjectiveIntelligibility": ["pystoi"]} +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"] + class ShortTimeObjectiveIntelligibility(Metric): r"""Calculates STOI (Short-Time Objective Intelligibility) metric for evaluating speech signals. Intelligibility @@ -72,6 +77,7 @@ class ShortTimeObjectiveIntelligibility(Metric): full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True + plot_options: dict = {"lower_bound": -20.0, "upper_bound": 5.0} def __init__( self, @@ -103,3 +109,57 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Computes metric.""" return self.sum_stoi / self.total + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + fig: Figure object + ax: Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + Examples: + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> metric = ShortTimeObjectiveIntelligibility(8000, False) + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility + >>> metric = ShortTimeObjectiveIntelligibility(8000, False) + >>> g = torch.manual_seed(1) + >>> preds = torch.randn(8000) + >>> target = torch.randn(8000) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(preds, target)) + >>> fig_, ax_ = metric.plot(values) + """ + val = val or self.compute() + fig, ax = plot_single_or_multi_val( + val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__ + ) + return fig, ax diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 28f0d3b3478..09d4ee3df40 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -25,6 +25,7 @@ signal_distortion_ratio, signal_noise_ratio, ) +from torchmetrics.functional.audio import short_time_objective_intelligibility from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality from torchmetrics.functional.audio.pit import permutation_invariant_training from torchmetrics.functional.classification.accuracy import binary_accuracy, multiclass_accuracy @@ -87,6 +88,12 @@ lambda: torch.randn(4), id="scale_invariant_signal_noise_ratio", ), + pytest.param( + partial(short_time_objective_intelligibility, fs=8000, extended=False), + lambda: torch.randn(8000), + lambda: torch.randn(8000), + id="short_time_objective_intelligibility", + ) ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From fb1fd2708125faf9ac022f2127a7c4775071645a Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Sun, 5 Feb 2023 15:33:53 -0500 Subject: [PATCH 35/45] Add plotting example for stoi audio metric --- examples/plotting.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/examples/plotting.py b/examples/plotting.py index bc6e6bca241..a6b3761430e 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -132,6 +132,25 @@ def si_snr_example(): return fig, ax +def stoi_example(): + from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility + + p = lambda: torch.randn(8000) + t = lambda: torch.randn(8000) + + # plot single value + metric = ShortTimeObjectiveIntelligibility(8000, False) + metric.update(p(), t()) + fig, ax = metric.plot() + + # plot multiple values + metric = ShortTimeObjectiveIntelligibility(8000, False) + vals = [metric(p(), t()) for _ in range(10)] + fig, ax = metric.plot(vals) + + return fig, ax + + def accuracy_example(): from torchmetrics.classification import MulticlassAccuracy @@ -203,6 +222,7 @@ def confusion_matrix_example(): "si-sdr": si_sdr_example, "snr": snr_example, "si-snr": si_snr_example, + "stoi": stoi_example, "mean_squared_error": mean_squared_error_example, "confusion_matrix": confusion_matrix_example, } From a7da32d2534927f7fd227bf0a397ef6d4c337c61 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Feb 2023 20:34:46 +0000 Subject: [PATCH 36/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/stoi.py | 3 +-- tests/unittests/utilities/test_plot.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 0107130b3b8..2288995e918 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -17,8 +17,7 @@ from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE -from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PYSTOI_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val __doctest_requires__ = {"ShortTimeObjectiveIntelligibility": ["pystoi"]} diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 09d4ee3df40..2a63c96a88e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -93,7 +93,7 @@ lambda: torch.randn(8000), lambda: torch.randn(8000), id="short_time_objective_intelligibility", - ) + ), ], ) @pytest.mark.parametrize("num_vals", [1, 5, 10]) From fc2a1fffd775261a770d7ec20dce1273781ad0e3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 6 Feb 2023 11:47:12 +0100 Subject: [PATCH 37/45] Update src/torchmetrics/audio/stoi.py --- src/torchmetrics/audio/stoi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 2288995e918..37e0c8ce9d9 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -23,7 +23,7 @@ __doctest_requires__ = {"ShortTimeObjectiveIntelligibility": ["pystoi"]} if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["SignalNoiseRatio.plot", "ScaleInvariantSignalNoiseRatio.plot"] + __doctest_skip__ = ["ShortTimeObjectiveIntelligibility.plot"] class ShortTimeObjectiveIntelligibility(Metric): From 5cadc8cfe73115af6f56f1b722cf8cca19fd7c58 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 11:27:51 +0000 Subject: [PATCH 38/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/pesq.py | 1 - src/torchmetrics/audio/pit.py | 1 - src/torchmetrics/audio/sdr.py | 2 -- src/torchmetrics/audio/snr.py | 2 -- src/torchmetrics/audio/stoi.py | 1 - 5 files changed, 7 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 8962939bdf1..67b9fecf37c 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -144,7 +144,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 5a47f44237a..25ae23ff1c9 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -118,7 +118,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 68555534ad2..76769784f97 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -136,7 +136,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 @@ -244,7 +243,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index c6ccd004efb..75c4ed53136 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -107,7 +107,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 @@ -211,7 +210,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 4d83bb19b6d..8f548f2cfd8 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -128,7 +128,6 @@ def plot( If `matplotlib` is not installed Examples: - .. plot:: :scale: 75 From d0b904d75f0b4d8c055836fa014fa7bf28b80aeb Mon Sep 17 00:00:00 2001 From: Shweta Jacob Date: Mon, 6 Feb 2023 13:19:30 -0500 Subject: [PATCH 39/45] Update return type in docstring for plot in audio metrics --- src/torchmetrics/audio/pesq.py | 3 +-- src/torchmetrics/audio/pit.py | 3 +-- src/torchmetrics/audio/sdr.py | 6 ++---- src/torchmetrics/audio/snr.py | 6 ++---- src/torchmetrics/audio/stoi.py | 3 +-- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index c1dff51871c..7dc2d995d13 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -136,8 +136,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 1f2f5fbe8d5..d28a4d94d2c 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -110,8 +110,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index f7f85ff2ee8..df3f5172630 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -128,8 +128,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: @@ -235,8 +234,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index 0ff7e9640c0..5baf1858c53 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -99,8 +99,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: @@ -202,8 +201,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index ffd6f804422..3aa55fbd860 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -120,8 +120,7 @@ def plot( ax: An matplotlib axis object. If provided will add plot to that axis Returns: - fig: Figure object - ax: Axes object + Figure and Axes object Raises: ModuleNotFoundError: From 47dffc78ba0e7b0ca90ef21018b64f654c75aaa4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Feb 2023 02:37:08 +0000 Subject: [PATCH 40/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/utilities/test_plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 2fae61d3069..11a757f6d1e 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -130,7 +130,7 @@ def test_single_multi_val_plotter_pit(metric, preds, target, num_vals): @pytest.mark.parametrize( - "metric, preds, target", + ("metric", "preds", "target"), [ pytest.param( binary_confusion_matrix, From c7f8964cba71a3e5686c4906bc7959853303875d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 10 Feb 2023 10:57:35 +0100 Subject: [PATCH 41/45] missing docstrings --- examples/plotting.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/plotting.py b/examples/plotting.py index 845ea52cac8..cc85c091ebf 100644 --- a/examples/plotting.py +++ b/examples/plotting.py @@ -18,6 +18,7 @@ def pesq_example(): + """Plot PESQ audio example.""" from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality p = lambda: torch.randn(8000) @@ -37,6 +38,7 @@ def pesq_example(): def pit_example(): + """Plot PIT audio example.""" from torchmetrics.audio.pit import PermutationInvariantTraining from torchmetrics.functional import scale_invariant_signal_noise_ratio @@ -57,6 +59,7 @@ def pit_example(): def sdr_example(): + """Plot SDR audio example.""" from torchmetrics.audio.sdr import SignalDistortionRatio p = lambda: torch.randn(8000) @@ -76,6 +79,7 @@ def sdr_example(): def si_sdr_example(): + """Plot SI-SDR audio example.""" from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio p = lambda: torch.randn(5) @@ -95,6 +99,7 @@ def si_sdr_example(): def snr_example(): + """Plot SNR audio example.""" from torchmetrics.audio.snr import SignalNoiseRatio p = lambda: torch.randn(4) @@ -114,6 +119,7 @@ def snr_example(): def si_snr_example(): + """Plot SI-SNR example.""" from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio p = lambda: torch.randn(4) @@ -133,6 +139,7 @@ def si_snr_example(): def stoi_example(): + """Plot STOI example.""" from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility p = lambda: torch.randn(8000) From 6f5bcbbb19dfca3627bb4b08d8557924aa23bac1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 10 Feb 2023 11:19:34 +0100 Subject: [PATCH 42/45] fix docs --- src/torchmetrics/audio/pesq.py | 1 - src/torchmetrics/audio/pit.py | 1 - src/torchmetrics/audio/sdr.py | 2 -- src/torchmetrics/audio/snr.py | 2 -- src/torchmetrics/audio/stoi.py | 1 - 5 files changed, 7 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 7dc2d995d13..737a82970bb 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -142,7 +142,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index d28a4d94d2c..915d4e177d0 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -116,7 +116,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index df3f5172630..63c616dfd23 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -134,7 +134,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 @@ -240,7 +239,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index 5baf1858c53..68b2e823b76 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -105,7 +105,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 @@ -207,7 +206,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 3aa55fbd860..56167d815a4 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -126,7 +126,6 @@ def plot( ModuleNotFoundError: If `matplotlib` is not installed - Examples: .. plot:: :scale: 75 From 0062a1ec627e9d848e5394d19defcef5c17b6b58 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 10 Feb 2023 12:06:10 +0100 Subject: [PATCH 43/45] try fixing typing --- src/torchmetrics/metric.py | 2 +- src/torchmetrics/utilities/plot.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 70e191273fc..8057ccf6335 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -561,7 +561,7 @@ def compute(self) -> Any: distributed backend. """ - def plot(self, *_: Any, **__: Any) -> None: + def plot(self, *_: Any, **__: Any) -> Any: """Override this method plot the metric value.""" raise NotImplementedError diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index 15b84f4533c..df08c0e5a8d 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -13,7 +13,7 @@ # limitations under the License. from itertools import product from math import ceil, floor, sqrt -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -41,7 +41,7 @@ def _error_on_missing_matplotlib() -> None: def plot_single_or_multi_val( - val: Union[Tensor, List[Tensor]], + val: Union[Tensor, Sequence[Tensor]], ax: Optional[_AX_TYPE] = None, # type: ignore[valid-type] higher_is_better: Optional[bool] = None, lower_bound: Optional[float] = None, From 6248cd574f957c8f9644d863102ce1a9e24711e8 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 17 Feb 2023 04:28:42 +0100 Subject: [PATCH 44/45] typing --- src/torchmetrics/audio/pesq.py | 2 +- src/torchmetrics/audio/pit.py | 2 +- src/torchmetrics/audio/sdr.py | 2 +- src/torchmetrics/audio/snr.py | 2 +- src/torchmetrics/audio/stoi.py | 2 +- src/torchmetrics/utilities/plot.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index 737a82970bb..ce8952d5356 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -126,7 +126,7 @@ def compute(self) -> Tensor: return self.sum_pesq / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index 915d4e177d0..cd68d56348b 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -100,7 +100,7 @@ def compute(self) -> Tensor: return self.sum_pit_metric / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 63c616dfd23..4b0bc7ad7cb 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -118,7 +118,7 @@ def compute(self) -> Tensor: return self.sum_sdr / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index 57325143606..08b044f331c 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -190,7 +190,7 @@ def compute(self) -> Tensor: return self.sum_si_snr / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 56167d815a4..75130a68159 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -110,7 +110,7 @@ def compute(self) -> Tensor: return self.sum_stoi / self.total def plot( - self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None ) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/utilities/plot.py b/src/torchmetrics/utilities/plot.py index df08c0e5a8d..70fa000c99f 100644 --- a/src/torchmetrics/utilities/plot.py +++ b/src/torchmetrics/utilities/plot.py @@ -81,7 +81,7 @@ def plot_single_or_multi_val( label = f"{legend_name} {i}" if legend_name else f"{i}" ax.plot(i, v.detach().cpu(), marker="o", markersize=10, linestyle="None", label=label) else: - val = torch.stack(val, 0) + val = torch.stack(list(val), 0) multi_series = val.ndim != 1 val = val.T if multi_series else val.unsqueeze(0) for i, v in enumerate(val): From e72df04bfa47dd6df9a2ae622998717d9e47ea7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Feb 2023 03:29:17 +0000 Subject: [PATCH 45/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/audio/pesq.py | 4 +--- src/torchmetrics/audio/pit.py | 4 +--- src/torchmetrics/audio/sdr.py | 4 +--- src/torchmetrics/audio/snr.py | 4 +--- src/torchmetrics/audio/stoi.py | 4 +--- 5 files changed, 5 insertions(+), 15 deletions(-) diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index ce8952d5356..03fe6ff7c77 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -125,9 +125,7 @@ def compute(self) -> Tensor: """Compute metric.""" return self.sum_pesq / self.total - def plot( - self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/audio/pit.py b/src/torchmetrics/audio/pit.py index cd68d56348b..689550ed821 100644 --- a/src/torchmetrics/audio/pit.py +++ b/src/torchmetrics/audio/pit.py @@ -99,9 +99,7 @@ def compute(self) -> Tensor: """Compute metric.""" return self.sum_pit_metric / self.total - def plot( - self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/audio/sdr.py b/src/torchmetrics/audio/sdr.py index 4b0bc7ad7cb..69972cee1d8 100644 --- a/src/torchmetrics/audio/sdr.py +++ b/src/torchmetrics/audio/sdr.py @@ -117,9 +117,7 @@ def compute(self) -> Tensor: """Compute metric.""" return self.sum_sdr / self.total - def plot( - self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/audio/snr.py b/src/torchmetrics/audio/snr.py index 08b044f331c..c09681755a9 100644 --- a/src/torchmetrics/audio/snr.py +++ b/src/torchmetrics/audio/snr.py @@ -189,9 +189,7 @@ def compute(self) -> Tensor: """Compute metric.""" return self.sum_si_snr / self.total - def plot( - self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 75130a68159..34db6e956b3 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -109,9 +109,7 @@ def compute(self) -> Tensor: """Compute metric.""" return self.sum_stoi / self.total - def plot( - self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. Args: