|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | from typing import Any, Callable, Optional
|
| 15 | +from warnings import warn |
15 | 16 |
|
16 |
| -from torch import Tensor, tensor |
| 17 | +from torch import Tensor |
17 | 18 |
|
18 |
| -from torchmetrics.functional.audio.si_snr import si_snr |
19 |
| -from torchmetrics.metric import Metric |
| 19 | +from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio |
20 | 20 |
|
21 | 21 |
|
22 |
| -class SI_SNR(Metric): |
| 22 | +class SI_SNR(ScaleInvariantSignalNoiseRatio): |
23 | 23 | """Scale-invariant signal-to-noise ratio (SI-SNR).
|
24 | 24 |
|
25 |
| - Forward accepts |
26 |
| -
|
27 |
| - - ``preds``: ``shape [...,time]`` |
28 |
| - - ``target``: ``shape [...,time]`` |
29 |
| -
|
30 |
| - Args: |
31 |
| - compute_on_step: |
32 |
| - Forward only calls ``update()`` and returns None if this is set to False. |
33 |
| - dist_sync_on_step: |
34 |
| - Synchronize metric state across processes at each ``forward()`` |
35 |
| - before returning the value at the step. |
36 |
| - process_group: |
37 |
| - Specify the process group on which synchronization is called. |
38 |
| - dist_sync_fn: |
39 |
| - Callback that performs the allgather operation on the metric state. When `None`, DDP |
40 |
| - will be used to perform the allgather. |
41 |
| -
|
42 |
| - Raises: |
43 |
| - TypeError: |
44 |
| - if target and preds have a different shape |
45 |
| -
|
46 |
| - Returns: |
47 |
| - average si-snr value |
| 25 | + .. deprecated:: v0.7 |
| 26 | + Use :class:`torchmetrics.ScaleInvariantSignalNoiseRatio`. Will be removed in v0.8. |
48 | 27 |
|
49 | 28 | Example:
|
50 | 29 | >>> import torch
|
51 |
| - >>> from torchmetrics import SI_SNR |
52 |
| - >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0]) |
53 |
| - >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0]) |
54 | 30 | >>> si_snr = SI_SNR()
|
55 |
| - >>> si_snr_val = si_snr(preds, target) |
56 |
| - >>> si_snr_val |
| 31 | + >>> si_snr(torch.tensor([2.5, 0.0, 2.0, 8.0]), torch.tensor([3.0, -0.5, 2.0, 7.0])) |
57 | 32 | tensor(15.0918)
|
58 |
| -
|
59 |
| - References: |
60 |
| - [1] Y. Luo and N. Mesgarani, "TaSNet: Time-Domain Audio Separation Network for Real-Time, Single-Channel Speech |
61 |
| - Separation," 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2018, pp. |
62 |
| - 696-700, doi: 10.1109/ICASSP.2018.8462116. |
63 | 33 | """
|
64 | 34 |
|
65 |
| - is_differentiable = True |
66 |
| - sum_si_snr: Tensor |
67 |
| - total: Tensor |
68 |
| - higher_is_better = True |
69 |
| - |
70 | 35 | def __init__(
|
71 | 36 | self,
|
72 | 37 | compute_on_step: bool = True,
|
73 | 38 | dist_sync_on_step: bool = False,
|
74 | 39 | process_group: Optional[Any] = None,
|
75 | 40 | dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
|
76 | 41 | ) -> None:
|
77 |
| - super().__init__( |
78 |
| - compute_on_step=compute_on_step, |
79 |
| - dist_sync_on_step=dist_sync_on_step, |
80 |
| - process_group=process_group, |
81 |
| - dist_sync_fn=dist_sync_fn, |
| 42 | + warn( |
| 43 | + "`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8", |
| 44 | + DeprecationWarning, |
82 | 45 | )
|
83 |
| - |
84 |
| - self.add_state("sum_si_snr", default=tensor(0.0), dist_reduce_fx="sum") |
85 |
| - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") |
86 |
| - |
87 |
| - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore |
88 |
| - """Update state with predictions and targets. |
89 |
| -
|
90 |
| - Args: |
91 |
| - preds: Predictions from model |
92 |
| - target: Ground truth values |
93 |
| - """ |
94 |
| - si_snr_batch = si_snr(preds=preds, target=target) |
95 |
| - |
96 |
| - self.sum_si_snr += si_snr_batch.sum() |
97 |
| - self.total += si_snr_batch.numel() |
98 |
| - |
99 |
| - def compute(self) -> Tensor: |
100 |
| - """Computes average SI-SNR.""" |
101 |
| - return self.sum_si_snr / self.total |
| 46 | + super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn) |
0 commit comments