Refactor/move args to kwargs #833

Merged
43 commits merged on Feb 21, 2022

Changes from 28 commits

Commits (43)
35fe10c
move kwargs
SkafteNicki Feb 9, 2022
4ba7728
update
SkafteNicki Feb 10, 2022
d8ac64d
link
SkafteNicki Feb 10, 2022
864e5f7
revert
SkafteNicki Feb 10, 2022
1ef093a
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 10, 2022
7086cb7
Apply suggestions from code review
Borda Feb 10, 2022
f0024fc
move static args
SkafteNicki Feb 11, 2022
1688868
typing
SkafteNicki Feb 11, 2022
065cde1
Update docs/source/pages/overview.rst
SkafteNicki Feb 11, 2022
d750ab3
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 11, 2022
9d318eb
include link in base
SkafteNicki Feb 11, 2022
1cd056a
audio
SkafteNicki Feb 16, 2022
4d216e4
wrappers
SkafteNicki Feb 16, 2022
c6598a7
detection
SkafteNicki Feb 16, 2022
ae1c5f7
image
SkafteNicki Feb 16, 2022
972dda7
retrieval
SkafteNicki Feb 16, 2022
d7fa5ae
regression
SkafteNicki Feb 16, 2022
659bb39
text
SkafteNicki Feb 16, 2022
4f74ca1
classification
SkafteNicki Feb 16, 2022
00ebd7b
ref
Borda Feb 16, 2022
6d6f9a7
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 16, 2022
e739dfe
reference
SkafteNicki Feb 16, 2022
81f99ca
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 16, 2022
bed642d
changelog
SkafteNicki Feb 16, 2022
4f211e5
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 16, 2022
803f11d
fix mypy
SkafteNicki Feb 16, 2022
53f9d13
fix inconsistency in args
SkafteNicki Feb 16, 2022
96a9fa0
remove restriction on process group
SkafteNicki Feb 16, 2022
b248cbb
rename to kwargs
SkafteNicki Feb 16, 2022
b9336ad
fix tests
SkafteNicki Feb 17, 2022
c5737e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
d14294c
add base testing
SkafteNicki Feb 17, 2022
abee6e4
fix remaining examples
SkafteNicki Feb 17, 2022
7a67fb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
ad20122
flake8
SkafteNicki Feb 17, 2022
75508e1
mypy
SkafteNicki Feb 17, 2022
8d3d807
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
147e279
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 17, 2022
2c7b08d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
1a5f5f0
mypy
SkafteNicki Feb 17, 2022
9a184a2
fix tests
SkafteNicki Feb 17, 2022
fec6586
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 18, 2022
7d4388e
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 21, 2022
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -23,12 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832))


- Added `**kwargs` argument for passing additional arguments to the base class ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


### Changed


### Deprecated

- Deprecated method `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))
- Deprecated argument `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))


- Deprecated passing in `dist_sync_on_step`, `process_group`, `dist_sync_fn` as direct arguments ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


### Removed
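To make the new changelog entries concrete, here is a minimal sketch, not taken from the PR, of what the refactor means for a metric subclass: the dedicated `dist_sync_on_step`, `process_group` and `dist_sync_fn` parameters disappear from the subclass signature and are instead forwarded to the `Metric` base class through `**kwargs`. The `MySum` metric below is a toy example; only the constructor pattern reflects the PR.

```python
from typing import Any, Dict, Optional

import torch
from torchmetrics import Metric


class MySum(Metric):
    """Toy metric illustrating the constructor pattern this PR converges on."""

    def __init__(self, compute_on_step: Optional[bool] = None, **kwargs: Dict[str, Any]) -> None:
        # dist_sync_on_step, process_group and dist_sync_fn are no longer declared
        # here; if a user passes them, they travel inside **kwargs and are consumed
        # by the Metric base class.
        super().__init__(compute_on_step=compute_on_step, **kwargs)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total


# The distributed settings are still accepted at the call site, now via kwargs:
metric = MySum(dist_sync_on_step=True)
metric.update(torch.tensor([1.0, 2.0]))
print(metric.compute())  # tensor(3.)
```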
21 changes: 21 additions & 0 deletions docs/source/pages/overview.rst
@@ -348,3 +348,24 @@ In practise this means that:
val = metric.compute() # this value cannot be back-propagated

A functional metric is differentiable if its corresponding modular metric is differentiable.

.. _Metric kwargs:

*****************************
Advanced distributed settings
*****************************

If you are running in a distributed environment, ``TorchMetrics`` will automatically take care of the distributed
synchronization for you. However, the following three keyword arguments can be given to any metric class for
further control over the distributed aggregation:

- ``dist_sync_on_step``: This argument is a ``bool`` that indicates if the metric should synchronize between
different devices every time ``forward`` is called. Setting this to ``True`` is in general not recommended,
as synchronization is an expensive operation to do after each batch.

- ``process_group``: By default we synchronize across the *world*, i.e. all processes being computed on. You
can provide a ``torch._C._distributed_c10d.ProcessGroup`` in this argument to specify exactly which
devices should be synchronized over.

- ``dist_sync_fn``: By default we use :func:`torch.distributed.all_gather` to perform the synchronization between
devices. Provide another callable function for this argument to perform custom distributed synchronization.
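A short usage sketch of the three keyword arguments described above; the values shown are simply the defaults, spelled out for illustration, and ``Accuracy`` stands in for any metric class:

```python
import torch
from torchmetrics import Accuracy

# In a distributed job these kwargs tune how the metric state is synchronized;
# in a single-process run they are accepted but have no effect.
metric = Accuracy(
    dist_sync_on_step=False,  # only synchronize when compute() is called
    process_group=None,       # None -> the default group, i.e. the whole world
    dist_sync_fn=None,        # None -> torch.distributed.all_gather
)

preds = torch.tensor([0, 1, 1, 0])
target = torch.tensor([0, 1, 0, 0])
metric.update(preds, target)
print(metric.compute())  # tensor(0.7500)
```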
10 changes: 10 additions & 0 deletions tests/test_deprecated.py
@@ -0,0 +1,10 @@
import pytest

from torchmetrics import Accuracy


def test_compute_on_step():
    with pytest.warns(
        DeprecationWarning, match="Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9"
    ):
        Accuracy(compute_on_step=False)  # any metric will raise the warning
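For users who still pass `compute_on_step` and want to silence the warning asserted above while they migrate, a small sketch based on the same warning text:

```python
import warnings

from torchmetrics import Accuracy

# Ignore only this specific deprecation while migrating existing code.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message="Argument `compute_on_step` is deprecated in v0.8",
        category=DeprecationWarning,
    )
    metric = Accuracy(compute_on_step=False)
```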
24 changes: 5 additions & 19 deletions torchmetrics/audio/pesq.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Dict, Optional

from torch import Tensor, tensor

@@ -49,15 +49,8 @@ class PerceptualEvaluationSpeechQuality(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ModuleNotFoundError:
@@ -94,16 +87,9 @@ def __init__(
fs: int,
mode: str,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)
if not _PESQ_AVAILABLE:
raise ModuleNotFoundError(
"PerceptualEvaluationSpeechQuality metric requires that `pesq` is installed."
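A brief construction sketch for the new signature (not part of the PR); it assumes the optional `pesq` backend is installed and only demonstrates that the distributed keyword now flows through `**kwargs` instead of a dedicated parameter:

```python
from torchmetrics.audio import PerceptualEvaluationSpeechQuality

# fs must be 8000 or 16000 and mode "nb" or "wb" (wide-band requires fs=16000).
# dist_sync_on_step is no longer a named parameter of this class; it is simply
# forwarded to the Metric base class via **kwargs.
pesq_metric = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb", dist_sync_on_step=False)
```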
21 changes: 3 additions & 18 deletions torchmetrics/audio/pit.py
@@ -43,16 +43,9 @@ class PermutationInvariantTraining(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
kwargs:
additional args for metric_func
Additional keyword arguments for either the `metric_func` or distributed communication,
see :ref:`Metric kwargs` for more info.

Returns:
average PermutationInvariantTraining metric
@@ -83,17 +76,9 @@ def __init__(
metric_func: Callable,
eval_func: str = "max",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)
self.metric_func = metric_func
self.eval_func = eval_func
self.kwargs = kwargs
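A small usage sketch for the updated signature (not from the PR; it assumes `signal_noise_ratio` is available in `torchmetrics.functional`, which is the case for the 0.7/0.8 line):

```python
import torch
from torchmetrics.audio import PermutationInvariantTraining
from torchmetrics.functional import signal_noise_ratio

# Per the updated docstring, any extra kwargs are shared between the Metric
# base class (distributed settings) and metric_func; none are needed here.
pit = PermutationInvariantTraining(signal_noise_ratio, eval_func="max")

preds = torch.randn(3, 2, 8)    # (batch, speakers, time)
target = torch.randn(3, 2, 8)
pit.update(preds, target)
print(pit.compute())            # average best-permutation SNR over the batch
```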
23 changes: 5 additions & 18 deletions torchmetrics/audio/sdr.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional

from torch import Tensor, tensor

@@ -53,14 +53,8 @@ class SignalDistortionRatio(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ModuleNotFoundError:
@@ -115,21 +109,14 @@ def __init__(
zero_mean: bool = False,
load_diag: Optional[float] = None,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Dict[str, Any],
) -> None:
if not _FAST_BSS_EVAL_AVAILABLE:
raise ModuleNotFoundError(
"SDR metric requires that `fast-bss-eval` is installed."
" Either install as `pip install torchmetrics[audio]` or `pip install fast-bss-eval`."
)
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)

self.use_cg_iter = use_cg_iter
self.filter_length = filter_length
23 changes: 5 additions & 18 deletions torchmetrics/audio/snr.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional

from torch import Tensor, tensor

@@ -43,14 +43,8 @@ class SignalNoiseRatio(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
TypeError:
@@ -82,16 +76,9 @@ def __init__(
self,
zero_mean: bool = False,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)
self.zero_mean = zero_mean

self.add_state("sum_snr", default=tensor(0.0), dist_reduce_fx="sum")
23 changes: 5 additions & 18 deletions torchmetrics/audio/stoi.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Dict, Optional

from torch import Tensor, tensor

@@ -52,14 +52,8 @@ class ShortTimeObjectiveIntelligibility(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Returns:
average STOI value
@@ -101,16 +95,9 @@ def __init__(
fs: int,
extended: bool = False,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)
if not _PYSTOI_AVAILABLE:
raise ModuleNotFoundError(
"STOI metric requires that `pystoi` is installed."
21 changes: 5 additions & 16 deletions torchmetrics/classification/accuracy.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from typing import Any, Dict, Optional

from torch import Tensor, tensor

@@ -132,15 +132,8 @@ class Accuracy(StatScores):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
@@ -184,9 +177,7 @@ def __init__(
multiclass: Optional[bool] = None,
subset_accuracy: bool = False,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
) -> None:
allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
if average not in allowed_average:
@@ -201,9 +192,7 @@ def __init__(
multiclass=multiclass,
ignore_index=ignore_index,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
**kwargs,
)

if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
23 changes: 5 additions & 18 deletions torchmetrics/classification/auc.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Optional
from typing import Any, Dict, List, Optional

from torch import Tensor

@@ -38,14 +38,8 @@ class AUC(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the ``allgather`` operation on the metric state. When ``None``, DDP
will be used to perform the ``allgather``.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.
"""
is_differentiable = False
x: List[Tensor]
@@ -55,16 +49,9 @@ def __init__(
self,
reorder: bool = False,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)

self.reorder = reorder
