Skip to content

Commit d3cab93

Browse files
authored
refactor: pyDeprecate (#745)
* deprecate: class * deprecate: function * cleaning
1 parent 83345d7 commit d3cab93

28 files changed

+118
-184
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
numpy>=1.17.2
22
torch>=1.3.1
3+
pyDeprecate==0.3.*
34
packaging

torchmetrics/audio/pit.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Dict, Optional
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor, tensor
1818

1919
from torchmetrics.functional.audio.pit import permutation_invariant_training
@@ -137,6 +137,7 @@ class PIT(PermutationInvariantTraining):
137137
Signal Process. ICASSP, IEEE, New Orleans, LA, 2017: pp. 241–245. https://doi.org/10.1109/ICASSP.2017.7952154.
138138
"""
139139

140+
@deprecated(target=PermutationInvariantTraining, deprecated_in="0.7", remove_in="0.8")
140141
def __init__(
141142
self,
142143
metric_func: Callable,
@@ -147,10 +148,4 @@ def __init__(
147148
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
148149
**kwargs: Dict[str, Any],
149150
) -> None:
150-
warn(
151-
"`PIT` was renamed to `PermutationInvariantTraining` in v0.7 and it will be removed in v0.8",
152-
DeprecationWarning,
153-
)
154-
super().__init__(
155-
metric_func, eval_func, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn, **kwargs
156-
)
151+
void(metric_func, eval_func, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn, **kwargs)

torchmetrics/audio/sdr.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor, tensor
1818

1919
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
@@ -173,6 +173,7 @@ class SDR(SignalDistortionRatio):
173173
tensor(-11.6051)
174174
"""
175175

176+
@deprecated(target=SignalDistortionRatio, deprecated_in="0.7", remove_in="0.8")
176177
def __init__(
177178
self,
178179
use_cg_iter: Optional[int] = None,
@@ -184,9 +185,7 @@ def __init__(
184185
process_group: Optional[Any] = None,
185186
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
186187
) -> None:
187-
warn("`SDR` was renamed to `SignalDistortionRatio` in v0.7 and it will be removed in v0.8", DeprecationWarning)
188-
189-
super().__init__(
188+
void(
190189
use_cg_iter,
191190
filter_length,
192191
zero_mean,

torchmetrics/audio/si_sdr.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor
1818

1919
from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio
@@ -32,6 +32,7 @@ class SI_SDR(ScaleInvariantSignalDistortionRatio):
3232
tensor(18.4030)
3333
"""
3434

35+
@deprecated(target=ScaleInvariantSignalDistortionRatio, deprecated_in="0.7", remove_in="0.8")
3536
def __init__(
3637
self,
3738
zero_mean: bool = False,
@@ -40,8 +41,4 @@ def __init__(
4041
process_group: Optional[Any] = None,
4142
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
4243
) -> None:
43-
warn(
44-
"`SI_SDR` was renamed to `ScaleInvariantSignalDistortionRatio` in v0.7 and it will be removed in v0.8",
45-
DeprecationWarning,
46-
)
47-
super().__init__(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
44+
void(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)

torchmetrics/audio/si_snr.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor
1818

1919
from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio
@@ -32,15 +32,12 @@ class SI_SNR(ScaleInvariantSignalNoiseRatio):
3232
tensor(15.0918)
3333
"""
3434

35+
@deprecated(target=ScaleInvariantSignalNoiseRatio, deprecated_in="0.7", remove_in="0.8")
3536
def __init__(
3637
self,
3738
compute_on_step: bool = True,
3839
dist_sync_on_step: bool = False,
3940
process_group: Optional[Any] = None,
4041
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
4142
) -> None:
42-
warn(
43-
"`SI_SNR` was renamed to `ScaleInvariantSignalNoiseRatio` in v0.7 and it will be removed in v0.8",
44-
DeprecationWarning,
45-
)
46-
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
43+
void(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)

torchmetrics/audio/snr.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor, tensor
1818

1919
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, snr
@@ -125,6 +125,7 @@ class SNR(SignalNoiseRatio):
125125
126126
"""
127127

128+
@deprecated(target=SignalNoiseRatio, deprecated_in="0.7", remove_in="0.8")
128129
def __init__(
129130
self,
130131
zero_mean: bool = False,
@@ -133,8 +134,7 @@ def __init__(
133134
process_group: Optional[Any] = None,
134135
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
135136
) -> None:
136-
warn("`SNR` was renamed to `SignalNoiseRatio` in v0.7 and it will be removed in v0.8", DeprecationWarning)
137-
super().__init__(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
137+
void(zero_mean, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
138138

139139

140140
class ScaleInvariantSignalNoiseRatio(Metric):

torchmetrics/classification/f_beta.py

+28-31
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

1716
import torch
17+
from deprecate import deprecated, void
1818
from torch import Tensor
1919

2020
from torchmetrics.classification.stat_scores import StatScores
@@ -188,6 +188,7 @@ class FBeta(FBetaScore):
188188
tensor(0.3333)
189189
"""
190190

191+
@deprecated(target=FBetaScore, deprecated_in="0.7", remove_in="0.8")
191192
def __init__(
192193
self,
193194
num_classes: Optional[int] = None,
@@ -203,20 +204,19 @@ def __init__(
203204
process_group: Optional[Any] = None,
204205
dist_sync_fn: Callable = None,
205206
) -> None:
206-
warn("`FBeta` was renamed to `FBetaScore` in v0.7 and it will be removed in v0.8", DeprecationWarning)
207-
super().__init__(
208-
num_classes=num_classes,
209-
beta=beta,
210-
threshold=threshold,
211-
average=average,
212-
mdmc_average=mdmc_average,
213-
ignore_index=ignore_index,
214-
top_k=top_k,
215-
multiclass=multiclass,
216-
compute_on_step=compute_on_step,
217-
dist_sync_on_step=dist_sync_on_step,
218-
process_group=process_group,
219-
dist_sync_fn=dist_sync_fn,
207+
void(
208+
num_classes,
209+
beta,
210+
threshold,
211+
average,
212+
mdmc_average,
213+
ignore_index,
214+
top_k,
215+
multiclass,
216+
compute_on_step,
217+
dist_sync_on_step,
218+
process_group,
219+
dist_sync_fn,
220220
)
221221

222222

@@ -363,9 +363,7 @@ class F1(F1Score):
363363
tensor(0.3333)
364364
"""
365365

366-
is_differentiable = False
367-
higher_is_better = True
368-
366+
@deprecated(target=F1Score, deprecated_in="0.7", remove_in="0.8")
369367
def __init__(
370368
self,
371369
num_classes: Optional[int] = None,
@@ -380,17 +378,16 @@ def __init__(
380378
process_group: Optional[Any] = None,
381379
dist_sync_fn: Callable = None,
382380
) -> None:
383-
warn("`F1` was renamed to `F1Score` in v0.7 and it will be removed in v0.8", DeprecationWarning)
384-
super().__init__(
385-
num_classes=num_classes,
386-
threshold=threshold,
387-
average=average,
388-
mdmc_average=mdmc_average,
389-
ignore_index=ignore_index,
390-
top_k=top_k,
391-
multiclass=multiclass,
392-
compute_on_step=compute_on_step,
393-
dist_sync_on_step=dist_sync_on_step,
394-
process_group=process_group,
395-
dist_sync_fn=dist_sync_fn,
381+
void(
382+
num_classes,
383+
threshold,
384+
average,
385+
mdmc_average,
386+
ignore_index,
387+
top_k,
388+
multiclass,
389+
compute_on_step,
390+
dist_sync_on_step,
391+
process_group,
392+
dist_sync_fn,
396393
)

torchmetrics/classification/hinge.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional, Union
15-
from warnings import warn
1615

16+
from deprecate import deprecated, void
1717
from torch import Tensor, tensor
1818

1919
from torchmetrics.functional.classification.hinge import MulticlassMode, _hinge_compute, _hinge_update
@@ -154,6 +154,7 @@ class Hinge(HingeLoss):
154154
155155
"""
156156

157+
@deprecated(target=HingeLoss, deprecated_in="0.7", remove_in="0.8")
157158
def __init__(
158159
self,
159160
squared: bool = False,
@@ -163,5 +164,4 @@ def __init__(
163164
process_group: Optional[Any] = None,
164165
dist_sync_fn: Callable = None,
165166
) -> None:
166-
warn("`Hinge` was renamed to `HingeLoss` in v0.7 and it will be removed in v0.8", DeprecationWarning)
167-
super().__init__(squared, multiclass_mode, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
167+
void(squared, multiclass_mode, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)

torchmetrics/classification/iou.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Optional
15-
from warnings import warn
1615

1716
import torch
17+
from deprecate import deprecated, void
1818

1919
from torchmetrics.classification.jaccard import JaccardIndex
2020

@@ -37,6 +37,7 @@ class IoU(JaccardIndex):
3737
3838
"""
3939

40+
@deprecated(target=JaccardIndex, deprecated_in="0.7", remove_in="0.8")
4041
def __init__(
4142
self,
4243
num_classes: int,
@@ -48,14 +49,13 @@ def __init__(
4849
dist_sync_on_step: bool = False,
4950
process_group: Optional[Any] = None,
5051
) -> None:
51-
warn("`IoU` was renamed to `JaccardIndex` in v0.7 and it will be removed in v0.8", DeprecationWarning)
52-
super().__init__(
53-
num_classes=num_classes,
54-
ignore_index=ignore_index,
55-
absent_score=absent_score,
56-
threshold=threshold,
57-
reduction=reduction,
58-
compute_on_step=compute_on_step,
59-
dist_sync_on_step=dist_sync_on_step,
60-
process_group=process_group,
52+
void(
53+
num_classes,
54+
ignore_index,
55+
absent_score,
56+
threshold,
57+
reduction,
58+
compute_on_step,
59+
dist_sync_on_step,
60+
process_group,
6161
)

torchmetrics/classification/matthews_corrcoef.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from typing import Any, Callable, Optional
15-
from warnings import warn
1615

1716
import torch
17+
from deprecate import deprecated, void
1818
from torch import Tensor
1919

2020
from torchmetrics.functional.classification.matthews_corrcoef import (
@@ -126,6 +126,7 @@ class MatthewsCorrcoef(MatthewsCorrCoef):
126126
Renamed in favor of :class:`torchmetrics.MatthewsCorrCoef`. Will be removed in v0.8.
127127
"""
128128

129+
@deprecated(target=MatthewsCorrCoef, deprecated_in="0.7", remove_in="0.8")
129130
def __init__(
130131
self,
131132
num_classes: int,
@@ -135,8 +136,4 @@ def __init__(
135136
process_group: Optional[Any] = None,
136137
dist_sync_fn: Callable = None,
137138
) -> None:
138-
warn(
139-
"`MatthewsCorrcoef` was renamed to `MatthewsCorrCoef` in v0.7 and it will be removed in v0.8",
140-
DeprecationWarning,
141-
)
142-
super().__init__(num_classes, threshold, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
139+
void(num_classes, threshold, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)

torchmetrics/functional/audio/pit.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from warnings import warn
1717

1818
import torch
19+
from deprecate import deprecated, void
1920
from torch import Tensor
2021

2122
from torchmetrics.utilities.checks import _check_same_shape
@@ -178,13 +179,12 @@ def permutation_invariant_training(
178179
return best_metric, best_perm
179180

180181

182+
@deprecated(target=permutation_invariant_training, deprecated_in="0.7", remove_in="0.8")
181183
def pit(
182184
preds: torch.Tensor, target: torch.Tensor, metric_func: Callable, eval_func: str = "max", **kwargs: Dict[str, Any]
183185
) -> Tuple[Tensor, Tensor]:
184186
"""Permutation invariant training. The ``pit`` implements the famous Permutation Invariant Training method.
185187
186-
[1] in speech separation field in order to calculate audio metrics in a permutation invariant way.
187-
188188
.. deprecated:: v0.7
189189
Use :func:`torchmetrics.functional.permutation_invariant_training`. Will be removed in v0.8.
190190
@@ -202,15 +202,8 @@ def pit(
202202
>>> pit_permutate(preds, best_perm)
203203
tensor([[[-0.0579, 0.3560, -0.9604],
204204
[-0.1719, 0.3205, 0.2951]]])
205-
206-
Reference:
207-
[1] `Permutation Invariant Training of Deep Models`_
208205
"""
209-
warn(
210-
"`pit` was renamed to `permutation_invariant_training` in v0.7 and it will be removed in v0.8",
211-
DeprecationWarning,
212-
)
213-
return permutation_invariant_training(preds, target, metric_func, eval_func, **kwargs)
206+
return void(preds, target, metric_func, eval_func, **kwargs)
214207

215208

216209
def pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:

0 commit comments

Comments
 (0)