Add plot() functionality to audio metrics #1434

Merged
merged 61 commits into from
Feb 17, 2023
Changes from 52 commits
Commits
e08e1c1
Add plot functionality to PerceptualEvaluationSpeechQuality
shhs29 Jan 8, 2023
5d301e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
3ff8372
Update CHANGELOG
shhs29 Jan 8, 2023
c785a94
Remove test for PerceptualEvaluationSpeechQuality plot
shhs29 Jan 8, 2023
c63e87c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
8c79aba
Skip doctest if matplotlib is not available
shhs29 Jan 8, 2023
4b51d13
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
428c4ba
Update pesq plot docs
shhs29 Jan 8, 2023
bd73378
Add test for pesq plotting
shhs29 Jan 8, 2023
88b2f2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 8, 2023
006797a
Update docs
shhs29 Jan 8, 2023
38a86c1
Update docs
shhs29 Jan 8, 2023
f4eca15
Add pesq plotting example
shhs29 Jan 9, 2023
35b9adb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2023
2effd88
Add pesq to init file
shhs29 Jan 9, 2023
e8e505d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2023
0829b4a
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Jan 10, 2023
64c894a
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Jan 13, 2023
58e979a
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Jan 18, 2023
2a51f12
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Jan 28, 2023
6d113ed
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Feb 2, 2023
ffd4d7d
fix
SkafteNicki Feb 3, 2023
d15501e
add requirement
SkafteNicki Feb 3, 2023
9463a6d
fix doctest
SkafteNicki Feb 3, 2023
3bec023
gh: update templates (#1477)Co-authored-by: pre-commit-ci[bot] <66853…
Borda Feb 3, 2023
a05a31b
Merge branch 'master' into add-plot-for-audio-metrics
SkafteNicki Feb 3, 2023
35b9d39
Merge branch 'master' into add-plot-for-audio-metrics
SkafteNicki Feb 3, 2023
40af3a2
Add plot function for PermutationInvariantTraining
shhs29 Feb 3, 2023
01ddadc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2023
6b81ad7
Update docstring for plot in PermutationInvariantTraining
shhs29 Feb 4, 2023
32ef72b
Add doctest requirement for pit
shhs29 Feb 4, 2023
77664e8
Fix docstring of pit plot
shhs29 Feb 4, 2023
c83d629
Fix plot test for pit
shhs29 Feb 4, 2023
52af69c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2023
f96aca8
Add plot functionality for sdr audio metrics
shhs29 Feb 5, 2023
0540c50
Add plot examples for sdr audio metrics
shhs29 Feb 5, 2023
7df844d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2023
6b252f9
Add plot functionality for snr audio metrics
shhs29 Feb 5, 2023
3f7abc0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2023
ec04abc
Add plotting example for snr audio metrics
shhs29 Feb 5, 2023
aeeb40d
Add plot functionality for stoi audio metric
shhs29 Feb 5, 2023
fb1fd27
Add plotting example for stoi audio metric
shhs29 Feb 5, 2023
a7da32d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 5, 2023
fc2a1ff
Update src/torchmetrics/audio/stoi.py
SkafteNicki Feb 6, 2023
422e8cc
Merge branch 'master' into add-plot-for-audio-metrics
Borda Feb 6, 2023
5cadc8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
51b51e3
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Feb 6, 2023
d0b904d
Update return type in docstring for plot in audio metrics
shhs29 Feb 6, 2023
6670851
Merge branch 'master' into add-plot-for-audio-metrics
Borda Feb 7, 2023
47dffc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 7, 2023
b3bc857
Merge branch 'master' into add-plot-for-audio-metrics
shhs29 Feb 7, 2023
c7f8964
missing docstrings
SkafteNicki Feb 10, 2023
6f5bcbb
fix docs
SkafteNicki Feb 10, 2023
0062a1e
try fixing typing
SkafteNicki Feb 10, 2023
d374b57
Merge branch 'master' into add-plot-for-audio-metrics
mergify[bot] Feb 13, 2023
6baeb64
Merge branch 'master' into add-plot-for-audio-metrics
mergify[bot] Feb 13, 2023
dd75eef
Merge branch 'master' into add-plot-for-audio-metrics
mergify[bot] Feb 13, 2023
d5e0df3
Merge branch 'master' into add-plot-for-audio-metrics
mergify[bot] Feb 14, 2023
f11660f
Merge branch 'master' into add-plot-for-audio-metrics
Borda Feb 17, 2023
6248cd5
typing
Borda Feb 17, 2023
e72df04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0)


- Added support for plotting of audio metrics through `.plot()` method ([#1434](https://github.com/Lightning-AI/metrics/pull/1434))


- Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))


148 changes: 148 additions & 0 deletions examples/plotting.py
@@ -17,6 +17,147 @@
import torch


def pesq_example():
    """Plot PESQ audio example."""
    from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

    p = lambda: torch.randn(8000)
    t = lambda: torch.randn(8000)

    # plot single value
    metric = PerceptualEvaluationSpeechQuality(8000, "nb")
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = PerceptualEvaluationSpeechQuality(16000, "wb")
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def pit_example():
    """Plot PIT audio example."""
    from torchmetrics.audio.pit import PermutationInvariantTraining
    from torchmetrics.functional import scale_invariant_signal_noise_ratio

    p = lambda: torch.randn(3, 2, 5)
    t = lambda: torch.randn(3, 2, 5)

    # plot single value
    metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, "max")
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, "max")
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def sdr_example():
    """Plot SDR audio example."""
    from torchmetrics.audio.sdr import SignalDistortionRatio

    p = lambda: torch.randn(8000)
    t = lambda: torch.randn(8000)

    # plot single value
    metric = SignalDistortionRatio()
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = SignalDistortionRatio()
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def si_sdr_example():
    """Plot SI-SDR audio example."""
    from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio

    p = lambda: torch.randn(5)
    t = lambda: torch.randn(5)

    # plot single value
    metric = ScaleInvariantSignalDistortionRatio()
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = ScaleInvariantSignalDistortionRatio()
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def snr_example():
    """Plot SNR audio example."""
    from torchmetrics.audio.snr import SignalNoiseRatio

    p = lambda: torch.randn(4)
    t = lambda: torch.randn(4)

    # plot single value
    metric = SignalNoiseRatio()
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = SignalNoiseRatio()
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def si_snr_example():
    """Plot SI-SNR example."""
    from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio

    p = lambda: torch.randn(4)
    t = lambda: torch.randn(4)

    # plot single value
    metric = ScaleInvariantSignalNoiseRatio()
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = ScaleInvariantSignalNoiseRatio()
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def stoi_example():
    """Plot STOI example."""
    from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

    p = lambda: torch.randn(8000)
    t = lambda: torch.randn(8000)

    # plot single value
    metric = ShortTimeObjectiveIntelligibility(8000, False)
    metric.update(p(), t())
    fig, ax = metric.plot()

    # plot multiple values
    metric = ShortTimeObjectiveIntelligibility(8000, False)
    vals = [metric(p(), t()) for _ in range(10)]
    fig, ax = metric.plot(vals)

    return fig, ax


def accuracy_example():
    """Plot Accuracy example."""
    from torchmetrics.classification import MulticlassAccuracy
@@ -85,6 +226,13 @@ def confusion_matrix_example():

metrics_func = {
    "accuracy": accuracy_example,
    "pesq": pesq_example,
    "pit": pit_example,
    "sdr": sdr_example,
    "si-sdr": si_sdr_example,
    "snr": snr_example,
    "si-snr": si_snr_example,
    "stoi": stoi_example,
    "mean_squared_error": mean_squared_error_example,
    "confusion_matrix": confusion_matrix_example,
}
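
The `metrics_func` mapping pairs a short name with an example function, which suggests the script looks up an example by name at run time. The part of `examples/plotting.py` that consumes the mapping is not shown in this diff, so the following is only a sketch of how such a registry could be dispatched. The `run_example` helper and the stub example functions are hypothetical and return strings instead of figures.

```python
import argparse
from typing import Callable, Dict, Sequence


def accuracy_example() -> str:
    """Stub standing in for a real plotting example (hypothetical)."""
    return "accuracy figure"


def pesq_example() -> str:
    """Stub standing in for a real plotting example (hypothetical)."""
    return "pesq figure"


# Same shape as the `metrics_func` mapping in the diff, but with stub callables.
metrics_func: Dict[str, Callable[[], str]] = {
    "accuracy": accuracy_example,
    "pesq": pesq_example,
}


def run_example(argv: Sequence[str]) -> str:
    """Parse a metric name from `argv` and call the matching example."""
    parser = argparse.ArgumentParser(description="Run one plotting example")
    # Restricting `choices` to the registry keys rejects unknown names early.
    parser.add_argument("metric", choices=sorted(metrics_func))
    args = parser.parse_args(argv)
    return metrics_func[args.metric]()
```

With this layout, registering a new example only requires one new dictionary entry, which matches how the audio examples are added in the hunk above.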
1 change: 1 addition & 0 deletions requirements/docs.txt
@@ -14,3 +14,4 @@ sphinx-copybutton>=0.3
# integrations
-r integrate.txt
-r visual.txt
-r audio.txt
55 changes: 53 additions & 2 deletions src/torchmetrics/audio/pesq.py
@@ -11,16 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _PESQ_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

__doctest_requires__ = {"PerceptualEvaluationSpeechQuality": ["pesq"]}

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["PerceptualEvaluationSpeechQuality.plot"]


class PerceptualEvaluationSpeechQuality(Metric):
    """Calculate `Perceptual Evaluation of Speech Quality`_ (PESQ). It's a recognized industry standard for audio
@@ -80,6 +84,7 @@ class PerceptualEvaluationSpeechQuality(Metric):
    full_state_update: bool = False
    is_differentiable: bool = False
    higher_is_better: bool = True
    plot_options: dict = {"lower_bound": 1.0, "upper_bound": 4.5}

    def __init__(
        self,
@@ -119,3 +124,49 @@ def update(self, preds: Tensor, target: Tensor) -> None:
    def compute(self) -> Tensor:
        """Compute metric."""
        return self.sum_pesq / self.total
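
`compute` returns a running sum divided by a running count, so the plotted value is the mean over everything passed to `update` so far. The `update` body is collapsed in this diff, so the following is a plain-Python sketch of that accumulate-then-divide pattern under that assumption. `RunningMean` is a hypothetical stand-in and does not handle tensors or distributed state.

```python
from typing import Sequence


class RunningMean:
    """Accumulate a sum and a count, then report their ratio (illustrative only)."""

    def __init__(self) -> None:
        self.total_sum = 0.0  # plays the role of `sum_pesq`
        self.count = 0        # plays the role of `total`

    def update(self, batch_scores: Sequence[float]) -> None:
        """Fold one batch of per-sample scores into the running state."""
        self.total_sum += sum(batch_scores)
        self.count += len(batch_scores)

    def compute(self) -> float:
        """Return the mean over all samples seen so far."""
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total_sum / self.count


mean = RunningMean()
mean.update([1.0, 2.0])
mean.update([3.0])
result = mean.compute()  # mean of 1.0, 2.0 and 3.0
```

Because only the sum and the count are stored, batches of different sizes are weighted per sample rather than per batch.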

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
                If no value is provided, `metric.compute` is called automatically and its result is plotted.
            ax: A matplotlib axis object. If provided, the plot is added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        Examples:
            .. plot::
                :scale: 75

                >>> # Example plotting a single value
                >>> import torch
                >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
                >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
                >>> metric.update(torch.rand(8000), torch.rand(8000))
                >>> fig_, ax_ = metric.plot()

            .. plot::
                :scale: 75

                >>> # Example plotting multiple values
                >>> import torch
                >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
                >>> metric = PerceptualEvaluationSpeechQuality(8000, 'nb')
                >>> values = []
                >>> for _ in range(10):
                ...     values.append(metric(torch.rand(8000), torch.rand(8000)))
                >>> fig_, ax_ = metric.plot(values)
        """
        # Fall back to compute() only when no value was passed. `val or self.compute()` would
        # also discard a zero-valued tensor and raise on a multi-element tensor.
        val = val if val is not None else self.compute()
        fig, ax = plot_single_or_multi_val(
            val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
        )
        return fig, ax
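
The fallback to `metric.compute` when no value is passed depends on telling "no value" apart from "a falsy value". The sketch below illustrates the difference between an `or` fallback and an explicit `None` check. `FakeTensor` is a hypothetical stand-in that only mimics tensor truthiness (a single element is truthy when non-zero, several elements raise); it is not a real tensor type.

```python
from typing import Any, Sequence


class FakeTensor:
    """Hypothetical stand-in that mimics tensor truthiness rules."""

    def __init__(self, values: Sequence[float]) -> None:
        self.values = list(values)

    def __bool__(self) -> bool:
        if len(self.values) != 1:
            raise RuntimeError("Boolean value of a multi-element tensor is ambiguous")
        return bool(self.values[0])


def pick_with_or(val: Any, fallback: Any) -> Any:
    """Fallback via truthiness: treats a zero score the same as a missing one."""
    return val or fallback


def pick_with_none_check(val: Any, fallback: Any) -> Any:
    """Fallback via identity: only a missing value triggers the fallback."""
    return fallback if val is None else val


zero_score = FakeTensor([0.0])
many_scores = FakeTensor([0.1, 0.2])

# A legitimate score of zero is silently replaced when `or` is used.
assert pick_with_or(zero_score, "computed") == "computed"
# The explicit check keeps the value that was passed in.
assert pick_with_none_check(zero_score, "computed") is zero_score
assert pick_with_none_check(many_scores, "computed") is many_scores
```

A non-empty list of values is truthy either way, so the difference only shows up for single zero-valued results and for multi-element containers whose truthiness is undefined.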
62 changes: 61 additions & 1 deletion src/torchmetrics/audio/pit.py
@@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Sequence, Union

from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.audio.pit import permutation_invariant_training
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

__doctest_requires__ = {"PermutationInvariantTraining": ["pit"]}

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["PermutationInvariantTraining.plot"]


class PermutationInvariantTraining(Metric):
@@ -60,6 +67,7 @@ class PermutationInvariantTraining(Metric):
    is_differentiable: bool = True
    sum_pit_metric: Tensor
    total: Tensor
    plot_options: dict = {"lower_bound": -10.0, "upper_bound": 1.0}

    def __init__(
        self,
@@ -90,3 +98,55 @@ def update(self, preds: Tensor, target: Tensor) -> None:
    def compute(self) -> Tensor:
        """Compute metric."""
        return self.sum_pit_metric / self.total

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute`, or a list of these results.
                If no value is provided, `metric.compute` is called automatically and its result is plotted.
            ax: A matplotlib axis object. If provided, the plot is added to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        Examples:
            .. plot::
                :scale: 75

                >>> # Example plotting a single value
                >>> import torch
                >>> from torchmetrics.audio.pit import PermutationInvariantTraining
                >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
                >>> preds = torch.randn(3, 2, 5)  # [batch, spk, time]
                >>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
                >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
                >>> metric.update(preds, target)
                >>> fig_, ax_ = metric.plot()

            .. plot::
                :scale: 75

                >>> # Example plotting multiple values
                >>> import torch
                >>> from torchmetrics.audio.pit import PermutationInvariantTraining
                >>> from torchmetrics.functional import scale_invariant_signal_noise_ratio
                >>> preds = torch.randn(3, 2, 5)  # [batch, spk, time]
                >>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
                >>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio, 'max')
                >>> values = []
                >>> for _ in range(10):
                ...     values.append(metric(preds, target))
                >>> fig_, ax_ = metric.plot(values)
        """
        # Fall back to compute() only when no value was passed. `val or self.compute()` would
        # also discard a zero-valued tensor and raise on a multi-element tensor.
        val = val if val is not None else self.compute()
        fig, ax = plot_single_or_multi_val(
            val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
        )
        return fig, ax
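
The docstring examples pass `[batch, spk, time]` tensors together with a pairwise metric and `'max'`. For readers unfamiliar with permutation-invariant evaluation, the sketch below shows the idea for a single sample in plain Python: score every assignment of predicted speakers to target speakers and keep the best one. It is an exhaustive search written for illustration and is not the library's implementation. `neg_mse` and `best_permutation` are hypothetical names.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]


def neg_mse(pred: Signal, target: Signal) -> float:
    """Toy pairwise metric where higher is better (negative mean squared error)."""
    return -sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def best_permutation(
    preds: Sequence[Signal],
    targets: Sequence[Signal],
    metric: Callable[[Signal, Signal], float],
    eval_func: str = "max",
) -> Tuple[float, Tuple[int, ...]]:
    """Return the best mean score and the speaker permutation that achieves it."""
    pick = max if eval_func == "max" else min
    scored: List[Tuple[float, Tuple[int, ...]]] = []
    for perm in permutations(range(len(preds))):
        # perm[i] is the index of the prediction matched to target i.
        score = sum(metric(preds[p], targets[i]) for i, p in enumerate(perm)) / len(perm)
        scored.append((score, perm))
    return pick(scored, key=lambda item: item[0])


# The two predicted speakers are in swapped order relative to the targets.
targets = [[1.0, 1.0], [0.0, 0.0]]
preds = [[0.0, 0.0], [1.0, 1.0]]
score, perm = best_permutation(preds, targets, neg_mse, "max")
```

In this toy case the swapped assignment `(1, 0)` gives a perfect score, while the identity assignment would be penalized even though both speakers were reconstructed exactly. The search is factorial in the number of speakers, which is acceptable only for small speaker counts.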