From 1ac0b63677dd720dd453772239bbc63fe1562d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Changsheng=20Quan=20=28=E5=85=A8=E6=98=8C=E7=9B=9B=29?= Date: Wed, 17 Jul 2024 19:09:40 +0800 Subject: [PATCH] Add new Audio metric DNSMOS (#2525) * +DNSMOS * update * +test * Apply suggestions from code review --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: Nicki Skafte Detlefsen Co-authored-by: Bas Krahmer --- CHANGELOG.md | 3 + ...p_noise_suppression_mean_opinion_score.rst | 21 ++ docs/source/links.rst | 1 + requirements/audio.txt | 6 +- src/torchmetrics/audio/__init__.py | 7 + src/torchmetrics/audio/dnsmos.py | 173 +++++++++++ src/torchmetrics/functional/audio/__init__.py | 7 + src/torchmetrics/functional/audio/dnsmos.py | 278 ++++++++++++++++++ src/torchmetrics/utilities/imports.py | 3 + tests/unittests/audio/test_dnsmos.py | 252 ++++++++++++++++ 10 files changed, 750 insertions(+), 1 deletion(-) create mode 100644 docs/source/audio/deep_noise_suppression_mean_opinion_score.rst create mode 100644 src/torchmetrics/audio/dnsmos.py create mode 100644 src/torchmetrics/functional/audio/dnsmos.py create mode 100644 tests/unittests/audio/test_dnsmos.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 513628fa9f1..c09bdab17e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added a new audio metric `DNSMOS` ([#2525](https://github.com/PyTorchLightning/metrics/pull/2525)) + + - Added `MetricInputTransformer` wrapper ([#2392](https://github.com/Lightning-AI/torchmetrics/pull/2392)) diff --git a/docs/source/audio/deep_noise_suppression_mean_opinion_score.rst b/docs/source/audio/deep_noise_suppression_mean_opinion_score.rst new file mode 100644 index 00000000000..371ad88c588 --- /dev/null +++ b/docs/source/audio/deep_noise_suppression_mean_opinion_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Deep Noise Suppression Mean Opinion Score (DNSMOS) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/audio_classification.svg + :tags: Audio + +.. include:: ../links.rst + +################################################## +Deep Noise Suppression Mean Opinion Score (DNSMOS) +################################################## + +Module Interface +________________ + +.. autoclass:: torchmetrics.audio.dnsmos.DeepNoiseSuppressionMeanOpinionScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.audio.dnsmos.deep_noise_suppression_mean_opinion_score diff --git a/docs/source/links.rst b/docs/source/links.rst index 0b0de53f325..e31b3362693 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -110,6 +110,7 @@ .. _Theils Uncertainty coefficient: https://en.wikipedia.org/wiki/Uncertainty_coefficient .. _Perceptual Evaluation of Speech Quality: https://en.wikipedia.org/wiki/Perceptual_Evaluation_of_Speech_Quality .. _pesq package: https://github.com/ludlows/python-pesq +.. _Deep Noise Suppression performance evaluation based on Mean Opinion Score: https://arxiv.org/abs/2010.15258 .. _Cees Taal's website: http://www.ceestaal.nl/code/ .. _pystoi package: https://github.com/mpariente/pystoi .. _stoi ref1: https://ieeexplore.ieee.org/abstract/document/5495701 diff --git a/requirements/audio.txt b/requirements/audio.txt index ba520163c29..ce717cab35a 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -5,4 +5,8 @@ pesq >=0.0.4, <0.0.5 pystoi >=0.3.0, <0.5.0 torchaudio >=0.10.0, <2.5.0 -gammatone >1.0.0, <1.1.0 +gammatone >=1.0.0, <1.1.0 +librosa >=0.9.0, <0.11.0 +onnxruntime-gpu >=1.12.0, <1.19; sys_platform != 'darwin' +onnxruntime >=1.12.0, <1.19; sys_platform == 'darwin' # installing onnxruntime-gpu failed on macos +requests >=2.19.0, <2.32.0 diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 31c01171c01..6d21902b13e 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -24,6 +24,8 @@ ) from torchmetrics.utilities.imports import ( _GAMMATONE_AVAILABLE, + _LIBROSA_AVAILABLE, + _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, _TORCHAUDIO_AVAILABLE, @@ -54,3 +56,8 @@ from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio __all__ += ["SpeechReverberationModulationEnergyRatio"] + +if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE: + from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore + + __all__ += ["DeepNoiseSuppressionMeanOpinionScore"] diff --git a/src/torchmetrics/audio/dnsmos.py b/src/torchmetrics/audio/dnsmos.py new file mode 100644 index 00000000000..a6d45aa11cb --- /dev/null +++ b/src/torchmetrics/audio/dnsmos.py @@ -0,0 +1,173 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import ( + _LIBROSA_AVAILABLE, + _MATPLOTLIB_AVAILABLE, + _ONNXRUNTIME_AVAILABLE, + _REQUESTS_AVAILABLE, +) +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +__doctest_requires__ = {"DeepNoiseSuppressionMeanOpinionScore": ["requests", "librosa", "onnxruntime"]} + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["DeepNoiseSuppressionMeanOpinionScore.plot"] + + +class DeepNoiseSuppressionMeanOpinionScore(Metric): + """Calculate `Deep Noise Suppression performance evaluation based on Mean Opinion Score`_ (DNSMOS). + + Human subjective evaluation is the ”gold standard” to evaluate speech quality optimized for human perception. + Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics + require a reference clean speech signal, which is unavailable in real recordings. The no-reference approaches + correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest + use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes + well in challenging test conditions with a high correlation to human ratings in stack ranking noise suppression + methods. More details can be found in `DNSMOS paper `_ and + `DNSMOS P.835 paper `_. + + + As input to ``forward`` and ``update`` the metric accepts the following input + + - ``preds`` (:class:`~torch.Tensor`): float tensor with shape ``(...,time)`` + + As output of ``forward`` and ``compute`` the metric returns the following output + + - ``dnsmos`` (:class:`~torch.Tensor`): float tensor of DNSMOS values reduced across the batch + with shape ``(...,4)`` indicating [p808_mos, mos_sig, mos_bak, mos_ovr] in the last dim. + + .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. + Install as ``pip install torchmetrics['audio']`` or alternatively `pip install librosa onnxruntime-gpu requests` + (if you do not have GPU enabled machine install `onnxruntime` instead of `onnxruntime-gpu`) + + .. note:: the ``forward`` and ``compute`` methods in this class return a reduced DNSMOS value + for a batch. To obtain the DNSMOS value for each sample, you may use the functional counterpart in + :func:`~torchmetrics.functional.audio.dnsmos.deep_noise_suppression_mean_opinion_score`. + + Args: + fs: sampling frequency + personalized: whether interfering speaker is penalized + device: the device used for calculating DNSMOS, can be cpu or cuda:n, where n is the index of gpu. + If None is given, then the device of input is used. + num_threads: number of threads to use for onnxruntime CPU inference. + + Raises: + ModuleNotFoundError: + If ``librosa``, ``onnxruntime`` or ``requests`` packages are not installed + + Example: + >>> from torch import randn + >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore + >>> g = torch.manual_seed(1) + >>> preds = randn(8000) + >>> dnsmos = DeepNoiseSuppressionMeanOpinionScore(8000, False) + >>> dnsmos(preds) + tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64) + + """ + + sum_dnsmos: Tensor + total: Tensor + full_state_update: bool = False + is_differentiable: bool = False + higher_is_better: bool = True + plot_lower_bound: float = 0 + plot_upper_bound: float = 5 + + def __init__( + self, + fs: int, + personalized: bool, + device: Optional[str] = None, + num_threads: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE: + raise ModuleNotFoundError( + "DNSMOS metric requires that librosa, onnxruntime and requests are installed." + " Install as `pip install librosa onnxruntime-gpu requests`." + ) + + self.fs = fs + self.personalized = personalized + self.cal_device = device + self.num_threads = num_threads + + self.add_state("sum_dnsmos", default=tensor([0, 0, 0, 0], dtype=torch.float64), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor) -> None: + """Update state with predictions.""" + metric_batch = deep_noise_suppression_mean_opinion_score( + preds, + self.fs, + self.personalized, + self.cal_device, + self.num_threads, + ).to(self.sum_dnsmos.device) + + self.sum_dnsmos += metric_batch.reshape(-1, 4).sum(dim=0) + self.total += metric_batch.reshape(-1, 4).shape[0] + + def compute(self) -> Tensor: + """Compute metric.""" + return self.sum_dnsmos / self.total + + def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these + results. If no value is provided, will automatically call ``metric.compute`` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If ``matplotlib`` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> import torch + >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore + >>> metric = DeepNoiseSuppressionMeanOpinionScore(8000, False) + >>> metric.update(torch.rand(8000)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.audio import DeepNoiseSuppressionMeanOpinionScore + >>> metric = DeepNoiseSuppressionMeanOpinionScore(8000, False) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(8000))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py index 077442b0b83..c8a8b5a4bcc 100644 --- a/src/torchmetrics/functional/audio/__init__.py +++ b/src/torchmetrics/functional/audio/__init__.py @@ -24,6 +24,8 @@ ) from torchmetrics.utilities.imports import ( _GAMMATONE_AVAILABLE, + _LIBROSA_AVAILABLE, + _ONNXRUNTIME_AVAILABLE, _PESQ_AVAILABLE, _PYSTOI_AVAILABLE, _TORCHAUDIO_AVAILABLE, @@ -55,3 +57,8 @@ from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio __all__ += ["speech_reverberation_modulation_energy_ratio"] + +if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE: + from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score + + __all__ += ["deep_noise_suppression_mean_opinion_score"] diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py new file mode 100644 index 00000000000..ce29b8538c2 --- /dev/null +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -0,0 +1,278 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import lru_cache +from typing import Any, Dict, Optional + +import numpy as np +import torch +from torch import Tensor + +from torchmetrics.utilities import rank_zero_info, rank_zero_warn +from torchmetrics.utilities.imports import _LIBROSA_AVAILABLE, _ONNXRUNTIME_AVAILABLE, _REQUESTS_AVAILABLE + +if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE and _REQUESTS_AVAILABLE: + import librosa + import onnxruntime as ort + import requests + from onnxruntime import InferenceSession +else: + librosa, ort, requests = None, None, None # type:ignore + + class InferenceSession: # type:ignore + """Dummy InferenceSession.""" + + def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + + +__doctest_requires__ = { + ("deep_noise_suppression_mean_opinion_score", "_load_session"): ["requests", "librosa", "onnxruntime"] +} + +SAMPLING_RATE = 16000 +INPUT_LENGTH = 9.01 +DNSMOS_DIR = "~/.torchmetrics/DNSMOS" + + +def _prepare_dnsmos(dnsmos_dir: str) -> None: + """Download required DNSMOS files. + + Args: + dnsmos_dir: a dir to save the downloaded files. Defaults to "~/.torchmetrics". + + """ + # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/model_v8.onnx + # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/DNSMOS/sig_bak_ovr.onnx + # https://raw.githubusercontent.com/microsoft/DNS-Challenge/master/DNSMOS/pDNSMOS/sig_bak_ovr.onnx + url = "https://raw.githubusercontent.com/microsoft/DNS-Challenge/master" + dnsmos_dir = os.path.expanduser(dnsmos_dir) + + # save to or load from ~/torchmetrics/dnsmos/. + for file in ["DNSMOS/DNSMOS/model_v8.onnx", "DNSMOS/DNSMOS/sig_bak_ovr.onnx", "DNSMOS/pDNSMOS/sig_bak_ovr.onnx"]: + saveto = os.path.join(dnsmos_dir, file[7:]) + os.makedirs(os.path.dirname(saveto), exist_ok=True) + if os.path.exists(saveto): + # try loading onnx + try: + _ = InferenceSession(saveto) + continue # skip downloading if succeeded + except Exception as _: + os.remove(saveto) + urlf = f"{url}/{file}" + rank_zero_info(f"downloading {urlf} to {saveto}") + myfile = requests.get(urlf) + with open(saveto, "wb") as f: + f.write(myfile.content) + + +@lru_cache +def _load_session( + path: str, + device: torch.device, + num_threads: Optional[int] = None, +) -> InferenceSession: + """Load onnxruntime session. + + Args: + path: the model path + device: the device used + num_threads: the number of threads to use. Defaults to None. + + Returns: + onnxruntime session + + """ + path = os.path.expanduser(path) + if not os.path.exists(path): + _prepare_dnsmos(DNSMOS_DIR) + + opts = ort.SessionOptions() + if num_threads is not None: + opts.inter_op_num_threads = num_threads + opts.intra_op_num_threads = num_threads + + if device.type == "cpu": + infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts) + elif "CUDAExecutionProvider" in ort.get_available_providers(): # win or linux with cuda + providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + provider_options = [{"device_id": device.index}, {}] + infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts) + elif "CoreMLExecutionProvider" in ort.get_available_providers(): # macos with coreml + providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"] + provider_options = [{"device_id": device.index}, {}] + infs = InferenceSession(path, providers=providers, provider_options=provider_options, sess_options=opts) + else: + infs = InferenceSession(path, providers=["CPUExecutionProvider"], sess_options=opts) + + return infs + + +def _audio_melspec( + audio: np.ndarray, + n_mels: int = 120, + frame_size: int = 320, + hop_length: int = 160, + sr: int = 16000, + to_db: bool = True, +) -> np.ndarray: + """Calculate the mel-spectrogram of an audio. + + Args: + audio: [..., T] + n_mels: the number of mel-frequencies + frame_size: stft length + hop_length: stft hop length + sr: sample rate of audio + to_db: convert to dB scale if `True` is given + + Returns: + mel-spectrogram: [..., num_mel, T'] + + """ + shape = audio.shape + audio = audio.reshape(-1, shape[-1]) + mel_spec = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels + ) + mel_spec = mel_spec.transpose(0, 2, 1) + mel_spec = mel_spec.reshape(shape[:-1] + mel_spec.shape[1:]) + if to_db: + for b in range(mel_spec.shape[0]): + mel_spec[b, ...] = (librosa.power_to_db(mel_spec[b], ref=np.max) + 40) / 40 + return mel_spec + + +def _polyfit_val(mos: np.ndarray, personalized: bool) -> np.ndarray: + """Use polyfit to convert raw mos values to DNSMOS values. + + Args: + mos: the raw mos values, [..., 4] + personalized: whether interfering speaker is penalized + + Returns: + DNSMOS: [..., 4] + + """ + if personalized: + p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) + p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) + else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) # x**2*v0 + x**1*v1+ v2 + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + mos[..., 1] = p_sig(mos[..., 1]) + mos[..., 2] = p_bak(mos[..., 2]) + mos[..., 3] = p_ovr(mos[..., 3]) + return mos + + +def deep_noise_suppression_mean_opinion_score( + preds: Tensor, fs: int, personalized: bool, device: Optional[str] = None, num_threads: Optional[int] = None +) -> Tensor: + """Calculate `Deep Noise Suppression performance evaluation based on Mean Opinion Score`_ (DNSMOS). + + Human subjective evaluation is the ”gold standard” to evaluate speech quality optimized for human perception. + Perceptual objective metrics serve as a proxy for subjective scores. The conventional and widely used metrics + require a reference clean speech signal, which is unavailable in real recordings. The no-reference approaches + correlate poorly with human ratings and are not widely adopted in the research community. One of the biggest + use cases of these perceptual objective metrics is to evaluate noise suppression algorithms. DNSMOS generalizes + well in challenging test conditions with a high correlation to human ratings in stack ranking noise suppression + methods. More details can be found in `DNSMOS paper `_ and + `DNSMOS P.835 paper `_. + + + .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install + as ``pip install torchmetrics['audio']`` or alternatively ``pip install librosa onnxruntime-gpu requests`` + (if you do not have GPU enabled machine install ``onnxruntime`` instead of ``onnxruntime-gpu``) + + Args: + preds: [..., time] + fs: sampling frequency + personalized: whether interfering speaker is penalized + device: the device used for calculating DNSMOS, can be cpu or cuda:n, where n is the index of gpu. + If None is given, then the device of input is used. + num_threads: the number of threads to use for cpu inference. Defaults to None. + + Returns: + Float tensor with shape ``(...,4)`` of DNSMOS values per sample, i.e. [p808_mos, mos_sig, mos_bak, mos_ovr] + + Raises: + ModuleNotFoundError: + If ``librosa``, ``onnxruntime`` or ``requests`` packages are not installed + + Example: + >>> from torch import randn + >>> from torchmetrics.functional.audio.dnsmos import deep_noise_suppression_mean_opinion_score + >>> g = torch.manual_seed(1) + >>> preds = randn(8000) + >>> deep_noise_suppression_mean_opinion_score(preds, 8000, False) + tensor([2.2285, 2.1132, 1.3972, 1.3652], dtype=torch.float64) + + """ + if not _LIBROSA_AVAILABLE or not _ONNXRUNTIME_AVAILABLE or not _REQUESTS_AVAILABLE: + raise ModuleNotFoundError( + "DNSMOS metric requires that librosa, onnxruntime and requests are installed." + " Install as `pip install librosa onnxruntime-gpu requests`." + ) + device = torch.device(device) if device is not None else preds.device + + onnx_sess = _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", device, num_threads) + p808_onnx_sess = _load_session(f"{DNSMOS_DIR}/DNSMOS/model_v8.onnx", device, num_threads) + + desired_fs = SAMPLING_RATE + if fs != desired_fs: + audio = librosa.resample(preds.cpu().numpy(), orig_sr=fs, target_sr=desired_fs) + else: + audio = preds.cpu().numpy() + + len_samples = int(INPUT_LENGTH * desired_fs) + while audio.shape[-1] < len_samples: + audio = np.concatenate([audio, audio], axis=-1) + + num_hops = int(np.floor(audio.shape[-1] / desired_fs) - INPUT_LENGTH) + 1 + + moss = [] + hop_len_samples = desired_fs + for idx in range(num_hops): + audio_seg = audio[..., int(idx * hop_len_samples) : int((idx + INPUT_LENGTH) * hop_len_samples)] + if audio_seg.shape[-1] < len_samples: + continue + shape = audio_seg.shape + audio_seg = audio_seg.reshape((-1, shape[-1])) + + input_features = np.array(audio_seg).astype("float32") + p808_input_features = np.array(_audio_melspec(audio=audio_seg[..., :-160])).astype("float32") + + if device.type != "cpu" and ( + "CUDAExecutionProvider" in ort.get_available_providers() + or "CoreMLExecutionProvider" in ort.get_available_providers() + ): + try: + input_features = ort.OrtValue.ortvalue_from_numpy(input_features, device.type, device.index) + p808_input_features = ort.OrtValue.ortvalue_from_numpy(p808_input_features, device.type, device.index) + except Exception as e: + rank_zero_warn(f"Failed to use GPU for DNSMOS, reverting to CPU. Error: {e}") + + oi = {"input_1": input_features} + p808_oi = {"input_1": p808_input_features} + mos_np = np.concatenate( + [p808_onnx_sess.run(None, p808_oi)[0], onnx_sess.run(None, oi)[0]], axis=-1, dtype="float64" + ) + mos_np = _polyfit_val(mos_np, personalized) + + mos_np = mos_np.reshape(shape[:-1] + (4,)) + moss.append(mos_np) + return torch.from_numpy(np.mean(np.stack(moss, axis=-1), axis=-1)) # [p808_mos, mos_sig, mos_bak, mos_ovr] diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index f21f29f8152..68692683351 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -48,6 +48,9 @@ _TORCHAUDIO_GREATER_EQUAL_0_10 = RequirementCache("torchaudio>=0.10.0") _REGEX_AVAILABLE = RequirementCache("regex") _PYSTOI_AVAILABLE = RequirementCache("pystoi") +_REQUESTS_AVAILABLE = RequirementCache("requests") +_LIBROSA_AVAILABLE = RequirementCache("librosa") +_ONNXRUNTIME_AVAILABLE = RequirementCache("onnxruntime") _FAST_BSS_EVAL_AVAILABLE = RequirementCache("fast_bss_eval") _MATPLOTLIB_AVAILABLE = RequirementCache("matplotlib") _SCIENCEPLOT_AVAILABLE = RequirementCache("scienceplots") diff --git a/tests/unittests/audio/test_dnsmos.py b/tests/unittests/audio/test_dnsmos.py new file mode 100644 index 00000000000..c3c1df7a03b --- /dev/null +++ b/tests/unittests/audio/test_dnsmos.py @@ -0,0 +1,252 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from functools import partial +from typing import Any, Dict, Optional + +import numpy as np +import pytest +import torch +from torch import Tensor +from torchmetrics.audio.dnsmos import DeepNoiseSuppressionMeanOpinionScore +from torchmetrics.functional.audio.dnsmos import ( + DNSMOS_DIR, + _load_session, + deep_noise_suppression_mean_opinion_score, +) +from torchmetrics.utilities.imports import ( + _LIBROSA_AVAILABLE, + _ONNXRUNTIME_AVAILABLE, + _REQUESTS_AVAILABLE, +) + +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + +if _LIBROSA_AVAILABLE and _ONNXRUNTIME_AVAILABLE and _REQUESTS_AVAILABLE: + import librosa + import onnxruntime as ort +else: + librosa, ort = None, None # type:ignore + + class InferenceSession: # type:ignore + """Dummy InferenceSession.""" + + def __init__(self, **kwargs: Dict[str, Any]) -> None: ... + + +SAMPLING_RATE = 16000 +INPUT_LENGTH = 9.01 +seed_all(42) + + +class _ComputeScore: + """The implementation from DNS-Challenge.""" + + def __init__(self, primary_model_path, p808_model_path) -> None: + self.onnx_sess = ort.InferenceSession(os.path.expanduser(primary_model_path)) + self.p808_onnx_sess = ort.InferenceSession(os.path.expanduser(p808_model_path)) + + def _audio_melspec(self, audio, n_mels=120, frame_size=320, hop_length=160, sr=16000, to_db=True): + mel_spec = librosa.feature.melspectrogram( + y=audio, sr=sr, n_fft=frame_size + 1, hop_length=hop_length, n_mels=n_mels + ) + if to_db: + mel_spec = (librosa.power_to_db(mel_spec, ref=np.max) + 40) / 40 + return mel_spec.T + + def _get_polyfit_val(self, sig, bak, ovr, is_personalized): + if is_personalized: + p_ovr = np.poly1d([-0.00533021, 0.005101, 1.18058466, -0.11236046]) + p_sig = np.poly1d([-0.01019296, 0.02751166, 1.19576786, -0.24348726]) + p_bak = np.poly1d([-0.04976499, 0.44276479, -0.1644611, 0.96883132]) + else: + p_ovr = np.poly1d([-0.06766283, 1.11546468, 0.04602535]) + p_sig = np.poly1d([-0.08397278, 1.22083953, 0.0052439]) + p_bak = np.poly1d([-0.13166888, 1.60915514, -0.39604546]) + + sig_poly = p_sig(sig) + bak_poly = p_bak(bak) + ovr_poly = p_ovr(ovr) + + return sig_poly, bak_poly, ovr_poly + + def __call__(self, aud, input_fs, is_personalized) -> Dict[str, Any]: + fs = SAMPLING_RATE + audio = librosa.resample(aud, orig_sr=input_fs, target_sr=fs) if input_fs != fs else aud + actual_audio_len = len(audio) + len_samples = int(INPUT_LENGTH * fs) + while len(audio) < len_samples: + audio = np.append(audio, audio) + + num_hops = int(np.floor(len(audio) / fs) - INPUT_LENGTH) + 1 + hop_len_samples = fs + predicted_mos_sig_seg_raw = [] + predicted_mos_bak_seg_raw = [] + predicted_mos_ovr_seg_raw = [] + predicted_mos_sig_seg = [] + predicted_mos_bak_seg = [] + predicted_mos_ovr_seg = [] + predicted_p808_mos = [] + + for idx in range(num_hops): + audio_seg = audio[int(idx * hop_len_samples) : int((idx + INPUT_LENGTH) * hop_len_samples)] + if len(audio_seg) < len_samples: + continue + + input_features = np.array(audio_seg).astype("float32")[np.newaxis, :] + p808_input_features = np.array(self._audio_melspec(audio=audio_seg[:-160])).astype("float32")[ + np.newaxis, :, : + ] + oi = {"input_1": input_features} + p808_oi = {"input_1": p808_input_features} + p808_mos = self.p808_onnx_sess.run(None, p808_oi)[0][0][0] + mos_sig_raw, mos_bak_raw, mos_ovr_raw = self.onnx_sess.run(None, oi)[0][0] + mos_sig, mos_bak, mos_ovr = self._get_polyfit_val(mos_sig_raw, mos_bak_raw, mos_ovr_raw, is_personalized) + predicted_mos_sig_seg_raw.append(mos_sig_raw) + predicted_mos_bak_seg_raw.append(mos_bak_raw) + predicted_mos_ovr_seg_raw.append(mos_ovr_raw) + predicted_mos_sig_seg.append(mos_sig) + predicted_mos_bak_seg.append(mos_bak) + predicted_mos_ovr_seg.append(mos_ovr) + predicted_p808_mos.append(p808_mos) + + return { + "len_in_sec": actual_audio_len / fs, + "sr": fs, + "num_hops": num_hops, + "OVRL_raw": np.mean(predicted_mos_ovr_seg_raw), + "SIG_raw": np.mean(predicted_mos_sig_seg_raw), + "BAK_raw": np.mean(predicted_mos_bak_seg_raw), + "OVRL": np.mean(predicted_mos_ovr_seg), + "SIG": np.mean(predicted_mos_sig_seg), + "BAK": np.mean(predicted_mos_bak_seg), + "P808_MOS": np.mean(predicted_p808_mos), + } + + +def _reference_metric_batch( + preds: Tensor, # shape:[BATCH_SIZE, Time] + target: Tensor, # for tester + fs: int, + personalized: bool, + device: Optional[str] = None, # for tester + reduce_mean: bool = False, + **kwargs: Dict[str, Any], # for tester +): + # download onnx first + _load_session(f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", torch.device("cpu")) + _load_session(f"{DNSMOS_DIR}/DNSMOS/model_v8.onnx", torch.device("cpu")) + # construct ComputeScore + cs = _ComputeScore( + f"{DNSMOS_DIR}/{'p' if personalized else ''}DNSMOS/sig_bak_ovr.onnx", + f"{DNSMOS_DIR}/DNSMOS/model_v8.onnx", + ) + + shape = preds.shape + preds = preds.reshape(1, -1) if len(shape) == 1 else preds.reshape(-1, shape[-1]) + + preds = preds.detach().cpu().numpy() + score = [] + for b in range(preds.shape[0]): + val = cs.__call__(preds[b, ...], fs, personalized) + score.append([val["P808_MOS"], val["SIG"], val["BAK"], val["OVRL"]]) + score = torch.tensor(score) + if reduce_mean: + # shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time] + # or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time] + return score.mean(dim=0) + return score.reshape(*shape[:-1], 4).reshape(shape[:-1] + (4,)).numpy() + + +def _dnsmos_cheat(preds, target, **kwargs: Dict[str, Any]): + # cheat the MetricTester as the deep_noise_suppression_mean_opinion_score doesn't need target + return deep_noise_suppression_mean_opinion_score(preds, **kwargs) + + +class _DNSMOSCheat(DeepNoiseSuppressionMeanOpinionScore): + # cheat the MetricTester as DeepNoiseSuppressionMeanOpinionScore doesn't need target + def update(self, preds: Tensor, target: Tensor) -> None: + super().update(preds=preds) + + +preds = torch.rand(2, 2, 8000) + + +@pytest.mark.parametrize( + "preds, fs, personalized", + [ + (preds, 8000, False), + (preds, 8000, True), + (preds, 16000, False), + (preds, 16000, True), + ], +) +class TestDNSMOS(MetricTester): + """Test class for `DeepNoiseSuppressionMeanOpinionScore` metric.""" + + atol = 5e-3 + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_dnsmos(self, preds: Tensor, fs: int, personalized: bool, ddp: bool, device=None): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + preds=preds, + target=preds, + metric_class=_DNSMOSCheat, + reference_metric=partial( + _reference_metric_batch, + fs=fs, + personalized=personalized, + device=device, + reduce_mean=True, + ), + metric_args={"fs": fs, "personalized": personalized, "device": device}, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_dnsmos_cuda(self, preds: Tensor, fs: int, personalized: bool, ddp: bool, device="cuda:0"): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + preds=preds, + target=preds, + metric_class=_DNSMOSCheat, + reference_metric=partial( + _reference_metric_batch, + fs=fs, + personalized=personalized, + device=device, + reduce_mean=True, + ), + metric_args={"fs": fs, "personalized": personalized, "device": device}, + ) + + def test_dnsmos_functional(self, preds: Tensor, fs: int, personalized: bool, device="cpu"): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=preds, + metric_functional=_dnsmos_cheat, + reference_metric=partial( + _reference_metric_batch, + fs=fs, + personalized=personalized, + device=device, + reduce_mean=False, + ), + metric_args={"fs": fs, "personalized": personalized, "device": device}, + )