Skip to content

Commit

Permalink
Refactor: SNR & SI_SNR (#712)
Browse files Browse the repository at this point in the history
* signal_noise_ratio
* scale_invariant_signal_noise_ratio
* SignalNoiseRatio
* ScaleInvariantSignalNoiseRatio

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] committed Jan 8, 2022
1 parent fdf5b3f commit d11641c
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 134 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))


- Added `ignore_index` to to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))
- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))


- Added support for multi references in `ROUGEScore` ([#680](https://github.com/PyTorchLightning/metrics/pull/680))
Expand Down Expand Up @@ -71,6 +71,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `SI_SDR` -> `ScaleInvariantSignalDistortionRatio`


- Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712))
* `functional.snr` -> `functional.signal_distortion_ratio`
* `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio`
* `SNR` -> `SignalNoiseRatio`
* `SI_SNR` -> `ScaleInvariantSignalNoiseRatio`


### Removed

- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ We currently have implemented metrics within the following domains:

- Audio (
[ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio),
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
[ScaleInvariantSignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalNoiseRatio),
[SignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SignalNoiseRatio)
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
)
- Classification (
Expand Down
10 changes: 5 additions & 5 deletions docs/source/references/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ scale_invariant_signal_distortion_ratio [func]
:noindex:


si_snr [func]
~~~~~~~~~~~~~
scale_invariant_signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.si_snr
:noindex:


snr [func]
~~~~~~~~~~
signal_noise_ratio [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.snr
.. autofunction:: torchmetrics.functional.signal_noise_ratio
:noindex:


Expand Down
23 changes: 14 additions & 9 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,17 @@ the metric will be computed over the ``time`` dimension.
.. doctest::

>>> import torch
>>> from torchmetrics import SignalNoiseRatio
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)
>>> from torchmetrics import SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> snr = SNR()
>>> snr_val = snr(preds, target)
>>> snr_val
>>> snr = SignalNoiseRatio()
>>> snr(preds, target)
tensor(16.1805)

PESQ
Expand All @@ -97,16 +102,16 @@ ScaleInvariantSignalDistortionRatio
.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
:noindex:

SI_SNR
~~~~~~
ScaleInvariantSignalNoiseRatio
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SI_SNR
.. autoclass:: torchmetrics.ScaleInvariantSignalNoiseRatio
:noindex:

SNR
~~~
SignalNoiseRatio
~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.SNR
.. autoclass:: torchmetrics.SignalNoiseRatio
:noindex:

STOI
Expand Down
24 changes: 17 additions & 7 deletions tests/audio/test_si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SI_SNR
from torchmetrics.functional import si_snr
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
from torchmetrics.functional import scale_invariant_signal_noise_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
ddp,
preds,
target,
SI_SNR,
ScaleInvariantSignalNoiseRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
)
Expand All @@ -88,12 +88,17 @@ def test_si_snr_functional(self, preds, target, sk_metric):
self.run_functional_metric_test(
preds,
target,
si_snr,
scale_invariant_signal_noise_ratio,
sk_metric,
)

def test_si_snr_differentiability(self, preds, target, sk_metric):
self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
self.run_differentiability_test(
preds=preds,
target=target,
metric_module=ScaleInvariantSignalNoiseRatio,
metric_functional=scale_invariant_signal_noise_ratio,
)

@pytest.mark.skipif(
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
Expand All @@ -103,10 +108,15 @@ def test_si_snr_half_cpu(self, preds, target, sk_metric):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_si_snr_half_gpu(self, preds, target, sk_metric):
self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=ScaleInvariantSignalNoiseRatio,
metric_functional=scale_invariant_signal_noise_ratio,
)


def test_error_on_different_shape(metric_class=SI_SNR):
def test_error_on_different_shape(metric_class=ScaleInvariantSignalNoiseRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
28 changes: 18 additions & 10 deletions tests/audio/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.audio import SNR
from torchmetrics.functional import snr
from torchmetrics.audio import SignalNoiseRatio
from torchmetrics.functional import signal_noise_ratio
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)

Time = 100
TIME = 100

Input = namedtuple("Input", ["preds", "target"])

inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
)


Expand Down Expand Up @@ -86,7 +86,7 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
ddp,
preds,
target,
SNR,
SignalNoiseRatio,
sk_metric=partial(average_metric, metric_func=sk_metric),
dist_sync_on_step=dist_sync_on_step,
metric_args=dict(zero_mean=zero_mean),
Expand All @@ -96,14 +96,18 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean):
self.run_functional_metric_test(
preds,
target,
snr,
signal_noise_ratio,
sk_metric,
metric_args=dict(zero_mean=zero_mean),
)

def test_snr_differentiability(self, preds, target, sk_metric, zero_mean):
self.run_differentiability_test(
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
preds=preds,
target=target,
metric_module=SignalNoiseRatio,
metric_functional=signal_noise_ratio,
metric_args={"zero_mean": zero_mean},
)

@pytest.mark.skipif(
Expand All @@ -115,11 +119,15 @@ def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean):
self.run_precision_test_gpu(
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
preds=preds,
target=target,
metric_module=SignalNoiseRatio,
metric_functional=signal_noise_ratio,
metric_args={"zero_mean": zero_mean},
)


def test_error_on_different_shape(metric_class=SNR):
def test_error_on_different_shape(metric_class=SignalNoiseRatio):
metric = metric_class()
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
metric(torch.randn(100), torch.randn(50))
4 changes: 4 additions & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
SI_SNR,
SNR,
ScaleInvariantSignalDistortionRatio,
ScaleInvariantSignalNoiseRatio,
SignalDistortionRatio,
SignalNoiseRatio,
)
from torchmetrics.classification import ( # noqa: E402, F401
AUC,
Expand Down Expand Up @@ -154,6 +156,8 @@
"ScaleInvariantSignalDistortionRatio",
"SI_SDR",
"SI_SNR",
"ScaleInvariantSignalNoiseRatio",
"SignalNoiseRatio",
"SNR",
"SpearmanCorrcoef",
"SpearmanCorrCoef",
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
from torchmetrics.audio.snr import SNR # noqa: F401
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401
77 changes: 11 additions & 66 deletions torchmetrics/audio/si_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,90 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from warnings import warn

from torch import Tensor, tensor
from torch import Tensor

from torchmetrics.functional.audio.si_snr import si_snr
from torchmetrics.metric import Metric
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio


class SI_SNR(Metric):
class SI_SNR(ScaleInvariantSignalNoiseRatio):
"""Scale-invariant signal-to-noise ratio (SI-SNR).
Forward accepts
- ``preds``: ``shape [...,time]``
- ``target``: ``shape [...,time]``
Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
Raises:
TypeError:
if target and preds have a different shape
Returns:
average si-snr value
.. deprecated:: v0.7
Use :class:`torchmetrics.ScaleInvariantSignalNoiseRatio`. Will be removed in v0.8.
Example:
>>> import torch
>>> from torchmetrics import SI_SNR
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> si_snr = SI_SNR()
>>> si_snr_val = si_snr(preds, target)
>>> si_snr_val
>>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0]))
tensor(15.0918)
References:
[1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech
Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp.
696-700, doi: 10.1109/ICASSP.2018.8462116.
"""

is_differentiable = True
sum_si_snr: Tensor
total: Tensor
higher_is_better = True

def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
warn(
"`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8",
DeprecationWarning,
)

self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
si_snr_batch = si_snr(preds=preds, target=target)

self.sum_si_snr += si_snr_batch.sum()
self.total += si_snr_batch.numel()

def compute(self) -> Tensor:
"""Computes average SI-SNR."""
return self.sum_si_snr / self.total
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
Loading

0 comments on commit d11641c

Please sign in to comment.