Skip to content

Commit b7518dd

Browse files
Refactor: SNR & SI_SNR (#712)
* signal_noise_ratio * scale_invariant_signal_noise_ratio * SignalNoiseRatio Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fdf5b3f commit b7518dd

File tree

14 files changed

+257
-134
lines changed

14 files changed

+257
-134
lines changed

CHANGELOG.md

+8-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))
2929

3030

31-
- Added `ignore_index` to to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))
31+
- Added `ignore_index` to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))
3232

3333

3434
- Added support for multi references in `ROUGEScore` ([#680](https://github.com/PyTorchLightning/metrics/pull/680))
@@ -71,6 +71,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
7171
* `SI_SDR` -> `ScaleInvariantSignalDistortionRatio`
7272

7373

74+
- Renamed audio SNR metrics: ([#712](https://github.com/PyTorchLightning/metrics/pull/712))
75+
* `functional.snr` -> `functional.signal_noise_ratio`
76+
* `functional.si_snr` -> `functional.scale_invariant_signal_noise_ratio`
77+
* `SNR` -> `SignalNoiseRatio`
78+
* `SI_SNR` -> `ScaleInvariantSignalNoiseRatio`
79+
80+
7481
### Removed
7582

7683
- Removed `embedding_similarity` metric ([#638](https://github.com/PyTorchLightning/metrics/pull/638))

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -267,8 +267,8 @@ We currently have implemented metrics within the following domains:
267267

268268
- Audio (
269269
[ScaleInvariantSignalDistortionRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalDistortionRatio),
270-
[SI_SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#si-snr),
271-
[SNR](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#snr)
270+
[ScaleInvariantSignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#ScaleInvariantSignalNoiseRatio),
271+
[SignalNoiseRatio](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#SignalNoiseRatio)
272272
and [few more](https://torchmetrics.readthedocs.io/en/latest/references/modules.html#audio-metrics)
273273
)
274274
- Classification (

docs/source/references/functional.rst

+5-5
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ scale_invariant_signal_distortion_ratio [func]
3838
:noindex:
3939

4040

41-
si_snr [func]
42-
~~~~~~~~~~~~~
41+
scale_invariant_signal_noise_ratio [func]
42+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4343

4444
.. autofunction:: torchmetrics.functional.scale_invariant_signal_noise_ratio
4545
:noindex:
4646

4747

48-
snr [func]
49-
~~~~~~~~~~
48+
signal_noise_ratio [func]
49+
~~~~~~~~~~~~~~~~~~~~~~~~~
5050

51-
.. autofunction:: torchmetrics.functional.snr
51+
.. autofunction:: torchmetrics.functional.signal_noise_ratio
5252
:noindex:
5353

5454

docs/source/references/modules.rst

+14-9
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,17 @@ the metric will be computed over the ``time`` dimension.
6666
.. doctest::
6767

6868
>>> import torch
69+
>>> from torchmetrics import SignalNoiseRatio
70+
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
71+
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
72+
>>> snr = SignalNoiseRatio()
73+
>>> snr(preds, target)
74+
tensor(16.1805)
6975
>>> from torchmetrics import SignalNoiseRatio
7076
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
7177
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
72-
>>> snr = SNR()
73-
>>> snr_val = snr(preds, target)
74-
>>> snr_val
78+
>>> snr = SignalNoiseRatio()
79+
>>> snr(preds, target)
7580
tensor(16.1805)
7681

7782
PESQ
@@ -97,16 +102,16 @@ ScaleInvariantSignalDistortionRatio
97102
.. autoclass:: torchmetrics.ScaleInvariantSignalDistortionRatio
98103
:noindex:
99104

100-
SI_SNR
101-
~~~~~~
105+
ScaleInvariantSignalNoiseRatio
106+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
102107

103-
.. autoclass:: torchmetrics.SI_SNR
108+
.. autoclass:: torchmetrics.ScaleInvariantSignalNoiseRatio
104109
:noindex:
105110

106-
SNR
107-
~~~
111+
SignalNoiseRatio
112+
~~~~~~~~~~~~~~~~
108113

109-
.. autoclass:: torchmetrics.SNR
114+
.. autoclass:: torchmetrics.SignalNoiseRatio
110115
:noindex:
111116

112117
STOI

tests/audio/test_si_snr.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from tests.helpers import seed_all
2323
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
24-
from torchmetrics.audio import SI_SNR
25-
from torchmetrics.functional import si_snr
24+
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio
25+
from torchmetrics.functional import scale_invariant_signal_noise_ratio
2626
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
2727

2828
seed_all(42)
@@ -79,7 +79,7 @@ def test_si_snr(self, preds, target, sk_metric, ddp, dist_sync_on_step):
7979
ddp,
8080
preds,
8181
target,
82-
SI_SNR,
82+
ScaleInvariantSignalNoiseRatio,
8383
sk_metric=partial(average_metric, metric_func=sk_metric),
8484
dist_sync_on_step=dist_sync_on_step,
8585
)
@@ -88,12 +88,17 @@ def test_si_snr_functional(self, preds, target, sk_metric):
8888
self.run_functional_metric_test(
8989
preds,
9090
target,
91-
si_snr,
91+
scale_invariant_signal_noise_ratio,
9292
sk_metric,
9393
)
9494

9595
def test_si_snr_differentiability(self, preds, target, sk_metric):
96-
self.run_differentiability_test(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
96+
self.run_differentiability_test(
97+
preds=preds,
98+
target=target,
99+
metric_module=ScaleInvariantSignalNoiseRatio,
100+
metric_functional=scale_invariant_signal_noise_ratio,
101+
)
97102

98103
@pytest.mark.skipif(
99104
not _TORCH_GREATER_EQUAL_1_6, reason="half support of core operations on not support before pytorch v1.6"
@@ -103,10 +108,15 @@ def test_si_snr_half_cpu(self, preds, target, sk_metric):
103108

104109
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
105110
def test_si_snr_half_gpu(self, preds, target, sk_metric):
106-
self.run_precision_test_gpu(preds=preds, target=target, metric_module=SI_SNR, metric_functional=si_snr)
111+
self.run_precision_test_gpu(
112+
preds=preds,
113+
target=target,
114+
metric_module=ScaleInvariantSignalNoiseRatio,
115+
metric_functional=scale_invariant_signal_noise_ratio,
116+
)
107117

108118

109-
def test_error_on_different_shape(metric_class=SI_SNR):
119+
def test_error_on_different_shape(metric_class=ScaleInvariantSignalNoiseRatio):
110120
metric = metric_class()
111121
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
112122
metric(torch.randn(100), torch.randn(50))

tests/audio/test_snr.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@
2222

2323
from tests.helpers import seed_all
2424
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
25-
from torchmetrics.audio import SNR
26-
from torchmetrics.functional import snr
25+
from torchmetrics.audio import SignalNoiseRatio
26+
from torchmetrics.functional import signal_noise_ratio
2727
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
2828

2929
seed_all(42)
3030

31-
Time = 100
31+
TIME = 100
3232

3333
Input = namedtuple("Input", ["preds", "target"])
3434

3535
inputs = Input(
36-
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
37-
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, Time),
36+
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
37+
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 1, TIME),
3838
)
3939

4040

@@ -86,7 +86,7 @@ def test_snr(self, preds, target, sk_metric, zero_mean, ddp, dist_sync_on_step):
8686
ddp,
8787
preds,
8888
target,
89-
SNR,
89+
SignalNoiseRatio,
9090
sk_metric=partial(average_metric, metric_func=sk_metric),
9191
dist_sync_on_step=dist_sync_on_step,
9292
metric_args=dict(zero_mean=zero_mean),
@@ -96,14 +96,18 @@ def test_snr_functional(self, preds, target, sk_metric, zero_mean):
9696
self.run_functional_metric_test(
9797
preds,
9898
target,
99-
snr,
99+
signal_noise_ratio,
100100
sk_metric,
101101
metric_args=dict(zero_mean=zero_mean),
102102
)
103103

104104
def test_snr_differentiability(self, preds, target, sk_metric, zero_mean):
105105
self.run_differentiability_test(
106-
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
106+
preds=preds,
107+
target=target,
108+
metric_module=SignalNoiseRatio,
109+
metric_functional=signal_noise_ratio,
110+
metric_args={"zero_mean": zero_mean},
107111
)
108112

109113
@pytest.mark.skipif(
@@ -115,11 +119,15 @@ def test_snr_half_cpu(self, preds, target, sk_metric, zero_mean):
115119
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
116120
def test_snr_half_gpu(self, preds, target, sk_metric, zero_mean):
117121
self.run_precision_test_gpu(
118-
preds=preds, target=target, metric_module=SNR, metric_functional=snr, metric_args={"zero_mean": zero_mean}
122+
preds=preds,
123+
target=target,
124+
metric_module=SignalNoiseRatio,
125+
metric_functional=signal_noise_ratio,
126+
metric_args={"zero_mean": zero_mean},
119127
)
120128

121129

122-
def test_error_on_different_shape(metric_class=SNR):
130+
def test_error_on_different_shape(metric_class=SignalNoiseRatio):
123131
metric = metric_class()
124132
with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"):
125133
metric(torch.randn(100), torch.randn(50))

torchmetrics/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
SI_SNR,
2121
SNR,
2222
ScaleInvariantSignalDistortionRatio,
23+
ScaleInvariantSignalNoiseRatio,
2324
SignalDistortionRatio,
25+
SignalNoiseRatio,
2426
)
2527
from torchmetrics.classification import ( # noqa: E402, F401
2628
AUC,
@@ -154,6 +156,8 @@
154156
"ScaleInvariantSignalDistortionRatio",
155157
"SI_SDR",
156158
"SI_SNR",
159+
"ScaleInvariantSignalNoiseRatio",
160+
"SignalNoiseRatio",
157161
"SNR",
158162
"SpearmanCorrcoef",
159163
"SpearmanCorrCoef",

torchmetrics/audio/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
from torchmetrics.audio.sdr import SDR, ScaleInvariantSignalDistortionRatio, SignalDistortionRatio # noqa: F401
1616
from torchmetrics.audio.si_sdr import SI_SDR # noqa: F401
1717
from torchmetrics.audio.si_snr import SI_SNR # noqa: F401
18-
from torchmetrics.audio.snr import SNR # noqa: F401
18+
from torchmetrics.audio.snr import SNR, ScaleInvariantSignalNoiseRatio, SignalNoiseRatio # noqa: F401

torchmetrics/audio/si_snr.py

+11-66
Original file line numberDiff line numberDiff line change
@@ -12,90 +12,35 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15+
from warnings import warn
1516

16-
from torch import Tensor, tensor
17+
from torch import Tensor
1718

18-
from torchmetrics.functional.audio.si_snr import si_snr
19-
from torchmetrics.metric import Metric
19+
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
2020

2121

22-
class SI_SNR(Metric):
22+
class SI_SNR(ScaleInvariantSignalNoiseRatio):
2323
"""Scale-invariant signal-to-noise ratio (SI-SNR).
2424
25-
Forward accepts
26-
27-
- ``preds``: ``shape [...,time]``
28-
- ``target``: ``shape [...,time]``
29-
30-
Args:
31-
compute_on_step:
32-
Forward only calls ``update()`` and returns None if this is set to False.
33-
dist_sync_on_step:
34-
Synchronize metric state across processes at each ``forward()``
35-
before returning the value at the step.
36-
process_group:
37-
Specify the process group on which synchronization is called.
38-
dist_sync_fn:
39-
Callback that performs the allgather operation on the metric state. When `None`, DDP
40-
will be used to perform the allgather.
41-
42-
Raises:
43-
TypeError:
44-
if target and preds have a different shape
45-
46-
Returns:
47-
average si-snr value
25+
.. deprecated:: v0.7
26+
Use :class:`torchmetrics.ScaleInvariantSignalNoiseRatio`. Will be removed in v0.8.
4827
4928
Example:
5029
>>> import torch
51-
>>> from torchmetrics import SI_SNR
52-
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
53-
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
5430
>>> si_snr = SI_SNR()
55-
>>> si_snr_val = si_snr(preds, target)
56-
>>> si_snr_val
31+
>>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0]))
5732
tensor(15.0918)
58-
59-
References:
60-
[1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech
61-
Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp.
62-
696-700, doi: 10.1109/ICASSP.2018.8462116.
6333
"""
6434

65-
is_differentiable = True
66-
sum_si_snr: Tensor
67-
total: Tensor
68-
higher_is_better = True
69-
7035
def __init__(
7136
self,
7237
compute_on_step: bool = True,
7338
dist_sync_on_step: bool = False,
7439
process_group: Optional[Any] = None,
7540
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
7641
) -> None:
77-
super().__init__(
78-
compute_on_step=compute_on_step,
79-
dist_sync_on_step=dist_sync_on_step,
80-
process_group=process_group,
81-
dist_sync_fn=dist_sync_fn,
42+
warn(
43+
"`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8",
44+
DeprecationWarning,
8245
)
83-
84-
self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum")
85-
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
86-
87-
def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
88-
"""Update state with predictions and targets.
89-
90-
Args:
91-
preds: Predictions from model
92-
target: Ground truth values
93-
"""
94-
si_snr_batch = si_snr(preds=preds, target=target)
95-
96-
self.sum_si_snr += si_snr_batch.sum()
97-
self.total += si_snr_batch.numel()
98-
99-
def compute(self) -> Tensor:
100-
"""Computes average SI-SNR."""
101-
return self.sum_si_snr / self.total
46+
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)

0 commit comments

Comments
 (0)