Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NISQA metric #2792

Merged
merged 27 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a030e32
Add NISQA metric
philgzl Oct 15, 2024
eeda544
Add tests
philgzl Oct 21, 2024
434f161
Merge branch 'master' into nisqa
philgzl Oct 21, 2024
8c04f27
Remove sg_execution_times.rst
philgzl Oct 21, 2024
1138e63
Merge branch 'master' into nisqa
philgzl Oct 21, 2024
f74d0d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
b7027b2
Fix weights path
philgzl Oct 21, 2024
19da6a3
Fix missing docstrings
philgzl Oct 21, 2024
6e3a913
Update CHANGELOG.md
philgzl Oct 21, 2024
aef1ffd
Fix typing
philgzl Oct 21, 2024
ab83d69
Fix can't pickle local object error
philgzl Oct 21, 2024
b2ffd8d
Fix can't pickle local object error
philgzl Oct 21, 2024
03eebe4
Fix DDP tests
philgzl Oct 22, 2024
72517b5
Merge branch 'master' into nisqa
philgzl Oct 22, 2024
8724e44
Increase atol
philgzl Oct 22, 2024
1670ba4
Merge branch 'master' into nisqa
Borda Oct 22, 2024
dc0c37a
Add docstrings and increase atol
philgzl Oct 22, 2024
ab2f2a3
Add link
philgzl Oct 22, 2024
3f45401
Fix dimension order in docstrings
philgzl Oct 22, 2024
0f25ae4
Merge branch 'master' into nisqa
SkafteNicki Oct 22, 2024
9867a6d
Update src/torchmetrics/audio/nisqa.py
philgzl Oct 22, 2024
10f2c65
Update src/torchmetrics/functional/audio/nisqa.py
philgzl Oct 22, 2024
48eb717
Update tests/unittests/audio/test_nisqa.py
philgzl Oct 22, 2024
64195d7
Merge branch 'master' into nisqa
SkafteNicki Oct 22, 2024
9f9f28e
Merge branch 'master' into nisqa
Borda Oct 29, 2024
23e0d6c
add additional copyright
SkafteNicki Oct 30, 2024
2b34960
Merge branch 'master' into nisqa
SkafteNicki Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))


- Added a new audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))


- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


Expand Down
21 changes: 21 additions & 0 deletions docs/source/audio/non_intrusive_speech_quality_assessment.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Non-Intrusive Speech Quality Assessment (NISQA v2.0)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg
:tags: Audio

.. include:: ../links.rst

####################################################
Non-Intrusive Speech Quality Assessment (NISQA v2.0)
####################################################

Module Interface
________________

.. autoclass:: torchmetrics.audio.nisqa.NonIntrusiveSpeechQualityAssessment
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.audio.nisqa.non_intrusive_speech_quality_assessment
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
.. _Perceptual Evaluation of Speech Quality: https://en.wikipedia.org/wiki/Perceptual_Evaluation_of_Speech_Quality
.. _pesq package: https://github.com/ludlows/python-pesq
.. _Deep Noise Suppression performance evaluation based on Mean Opinion Score: https://arxiv.org/abs/2010.15258
.. _Non-Intrusive Speech Quality Assessment: https://arxiv.org/abs/2104.09494
.. _Cees Taal's website: http://www.ceestaal.nl/code/
.. _pystoi package: https://github.com/mpariente/pystoi
.. _stoi ref1: https://ieeexplore.ieee.org/abstract/document/5495701
Expand Down
6 changes: 6 additions & 0 deletions src/torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_REQUESTS_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
)
Expand Down Expand Up @@ -68,3 +69,8 @@
from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore

__all__ += ["DeepNoiseSuppressionMeanOpinionScore"]

if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
from torchmetrics.audio.nisqa import NonIntrusiveSpeechQualityAssessment

__all__ += ["NonIntrusiveSpeechQualityAssessment"]
152 changes: 152 additions & 0 deletions src/torchmetrics/audio/nisqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Sequence, Union

from torch import Tensor, tensor

from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment
from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import (
_LIBROSA_AVAILABLE,
_MATPLOTLIB_AVAILABLE,
_REQUESTS_AVAILABLE,
)
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

__doctest_requires__ = {"NonIntrusiveSpeechQualityAssessment": ["librosa", "requests"]}

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["NonIntrusiveSpeechQualityAssessment.plot"]


class NonIntrusiveSpeechQualityAssessment(Metric):
"""`Non-Intrusive Speech Quality Assessment`_ (NISQA v2.0) [1], [2].

As input to ``forward`` and ``update`` the metric accepts the following input

- ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)``

As output of ``forward`` and ``compute`` the metric returns the following output

- ``nisqa`` (:class:`~torch.Tensor`): float tensor reduced across the batch with shape ``(5,)`` corresponding to
overall MOS, noisiness, discontinuity, coloration and loudness in that order

.. note:: Using this metric requires you to have ``librosa`` and ``requests`` installed. Install as
``pip install librosa requests``.

.. note:: The ``forward`` and ``compute`` methods in this class return values reduced across the batch. To obtain
values for each sample, you may use the functional counterpart
:func:`~torchmetrics.functional.audio.nisqa.non_intrusive_speech_quality_assessment`.

Args:
fs: sampling frequency of input

Raises:
ModuleNotFoundError:
If ``librosa`` or ``requests`` are not installed

Example:
>>> import torch
>>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
>>> _ = torch.manual_seed(42)
>>> preds = torch.randn(16000)
>>> nisqa = NonIntrusiveSpeechQualityAssessment(16000)
>>> nisqa(preds)
tensor([1.0433, 1.9545, 2.6087, 1.3460, 1.7117])

References:
- [1] G. Mittag and S. Möller, "Non-intrusive speech quality assessment for super-wideband speech communication
networks", in Proc. ICASSP, 2019.
- [2] G. Mittag, B. Naderi, A. Chehadi and S. Möller, "NISQA: A deep CNN-self-attention model for
multidimensional speech quality prediction with crowdsourced datasets", in Proc. INTERSPEECH, 2021.

"""

sum_nisqa: Tensor
total: Tensor
full_state_update: bool = False
is_differentiable: bool = False
higher_is_better: bool = True
plot_lower_bound: float = 0.0
plot_upper_bound: float = 5.0

def __init__(self, fs: int, **kwargs: Any) -> None:
super().__init__(**kwargs)
if not _LIBROSA_AVAILABLE or not _REQUESTS_AVAILABLE:
raise ModuleNotFoundError(
"NISQA metric requires that librosa and requests are installed. "
"Install as `pip install librosa requests`."
)
if not isinstance(fs, int) or fs <= 0:
raise ValueError(f"Argument `fs` expected to be a positive integer, but got {fs}")
self.fs = fs
philgzl marked this conversation as resolved.
Show resolved Hide resolved

self.add_state("sum_nisqa", default=tensor([0.0, 0.0, 0.0, 0.0, 0.0]), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor) -> None:
"""Update state with predictions."""
nisqa_batch = non_intrusive_speech_quality_assessment(
preds,
self.fs,
).to(self.sum_nisqa.device)

nisqa_batch = nisqa_batch.reshape(-1, 5)
self.sum_nisqa += nisqa_batch.sum(dim=0)
self.total += nisqa_batch.shape[0]

def compute(self) -> Tensor:
"""Compute metric."""
return self.sum_nisqa / self.total

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these
results. If no value is provided, will automatically call ``metric.compute`` and plot that result.
ax: A matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If ``matplotlib`` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
>>> metric = NonIntrusiveSpeechQualityAssessment(16000)
>>> metric.update(torch.randn(16000))
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import NonIntrusiveSpeechQualityAssessment
>>> metric = NonIntrusiveSpeechQualityAssessment(16000)
>>> values = []
>>> for _ in range(10):
... values.append(metric(torch.randn(16000)))
>>> fig_, ax_ = metric.plot(values)

"""
return self._plot(val, ax)
6 changes: 6 additions & 0 deletions src/torchmetrics/functional/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
_ONNXRUNTIME_AVAILABLE,
_PESQ_AVAILABLE,
_PYSTOI_AVAILABLE,
_REQUESTS_AVAILABLE,
_SCIPI_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
)
Expand Down Expand Up @@ -69,3 +70,8 @@
from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score

__all__ += ["deep_noise_suppression_mean_opinion_score"]

if _LIBROSA_AVAILABLE and _REQUESTS_AVAILABLE:
from torchmetrics.functional.audio.nisqa import non_intrusive_speech_quality_assessment

__all__ += ["non_intrusive_speech_quality_assessment"]
Loading
Loading