Commit

Refactor: move args to kwargs (#833)
* move static args
* include link in base
+ audio
+ wrappers
+ detection
+ image
+ retrieval
+ regression
+ text
+ classification
* fix inconsistency in args
* remove restriction on process group
* rename to kwargs
* add base testing
* mypy

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
5 people authored Feb 21, 2022
1 parent b87c3b8 commit 478576e
Showing 76 changed files with 590 additions and 1,346 deletions.
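All the diffs below apply the same pattern: the distributed-synchronization arguments (`compute_on_step`, `dist_sync_on_step`, `process_group`, `dist_sync_fn`) that each metric previously accepted and forwarded explicitly are now collected into `**kwargs` and passed straight to the `Metric` base class. As a rough sketch of what a user-defined metric looks like after this change (the `ClampedSum` class below is hypothetical and only illustrates the pattern):

import torch
from torchmetrics import Metric


class ClampedSum(Metric):
    """Hypothetical metric, used only to illustrate the new constructor pattern."""

    def __init__(self, max_value: float = 10.0, **kwargs):
        # Previously a subclass had to accept and forward compute_on_step,
        # dist_sync_on_step, process_group and dist_sync_fn one by one;
        # after this commit they simply travel through **kwargs.
        super().__init__(**kwargs)
        self.max_value = max_value
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor) -> None:
        self.total += torch.clamp(value.sum(), max=self.max_value)

    def compute(self) -> torch.Tensor:
        return self.total


# Distributed-sync options are now ordinary keyword arguments:
metric = ClampedSum(max_value=5.0, dist_sync_on_step=False)
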
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -23,12 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832))


+- Added `**kwargs` argument for passing additional arguments to base class ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


 ### Changed


 ### Deprecated

-- Deprecated method `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))
+- Deprecated argument `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))


+- Deprecated passing `dist_sync_on_step`, `process_group`, `dist_sync_fn` as direct arguments ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


 ### Removed
21 changes: 21 additions & 0 deletions docs/source/pages/overview.rst
@@ -348,3 +348,24 @@ In practise this means that:
val = metric.compute() # this value cannot be back-propagated
A functional metric is differentiable if its corresponding modular metric is differentiable.

.. _Metric kwargs:

*****************************
Advanced distributed settings
*****************************

If you are running in a distributed environment, ``TorchMetrics`` will automatically take care of the distributed
synchronization for you. However, the following three keyword arguments can be given to any metric class for
further control over the distributed aggregation:

- ``dist_sync_on_step``: This argument is a ``bool`` that indicates if the metric should synchronize between
  different devices every time ``forward`` is called. Setting this to ``True`` is in general not recommended,
  as synchronization is an expensive operation to do after each batch.

- ``process_group``: By default we synchronize across the *world*, i.e. all processes being computed on. You
  can provide a ``torch._C._distributed_c10d.ProcessGroup`` in this argument to specify exactly which
  devices should be synchronized over.

- ``dist_sync_fn``: By default we use :func:`torch.distributed.all_gather` to perform the synchronization between
  devices. Provide another callable function for this argument to perform custom distributed synchronization.
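The three keyword arguments above can be handed to any metric constructor. A minimal usage sketch, assuming a distributed setup is already initialized where needed (the custom gather function and the process-group choice below are illustrative assumptions, not part of this commit):

import torch
from torchmetrics import Accuracy

# Synchronize states on every forward() call -- supported, but generally discouraged.
metric = Accuracy(dist_sync_on_step=True)

# Restrict synchronization to a subset of ranks (requires torch.distributed to be initialized):
# group = torch.distributed.new_group(ranks=[0, 1])
# metric = Accuracy(process_group=group)

# Replace torch.distributed.all_gather with a custom gather function:
def single_process_gather(tensor: torch.Tensor, group=None) -> list:
    """Hypothetical stand-in that mimics the expected gather signature on a single process."""
    return [tensor]

metric = Accuracy(dist_sync_fn=single_process_gather)
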
11 changes: 11 additions & 0 deletions tests/bases/test_metric.py
@@ -27,11 +27,22 @@
seed_all(42)


def test_error_on_wrong_input():
    """Test that the base metric class raises an error on wrong input types."""
    with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_on_step` to be an `bool` but.*"):
        DummyMetric(dist_sync_on_step=None)

    with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_fn` to be an callable function.*"):
        DummyMetric(dist_sync_fn=[2, 3])


def test_inherit():
    """Test that a metric that inherits can be instantiated."""
    DummyMetric()


def test_add_state():
    """Test that the add_state method works as expected."""
    a = DummyMetric()

    a.add_state("a", tensor(0), "sum")
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
@@ -563,8 +563,8 @@ def run_differentiability_test(
 class DummyMetric(Metric):
     name = "Dummy"

-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.add_state("x", tensor(0.0), dist_reduce_fx=None)

     def update(self):
10 changes: 10 additions & 0 deletions tests/test_deprecated.py
@@ -0,0 +1,10 @@
import pytest

from torchmetrics import Accuracy


def test_compute_on_step():
    with pytest.warns(
        DeprecationWarning, match="Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9"
    ):
        Accuracy(compute_on_step=False)  # any metric will raise the warning
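For context on the deprecation being tested: `compute_on_step` used to control whether `forward` returned the metric value for the current batch; per the docstrings in this commit the argument has no effect anymore, so `forward` always returns the batch value and the old `compute_on_step=False` pattern reduces to `update` plus `compute`. A rough sketch of the replacement (the example tensors are made up):

import torch
from torchmetrics import Accuracy

preds = torch.tensor([0, 1, 1, 0])   # made-up predictions
target = torch.tensor([0, 1, 0, 0])  # made-up labels

# Old pattern (now warns): Accuracy(compute_on_step=False)

# New pattern: use update() when only the accumulated value is needed ...
metric = Accuracy()
metric.update(preds, target)
print(metric.compute())  # accuracy over everything seen so far

# ... and forward() when the per-batch value is wanted as well.
batch_acc = metric(preds, target)
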
25 changes: 3 additions & 22 deletions tests/wrappers/test_multioutput.py
@@ -1,6 +1,5 @@
 from collections import namedtuple
 from functools import partial
-from typing import Any, Callable, Optional

 import pytest
 import torch
@@ -24,30 +23,12 @@ def __init__(
         self,
         base_metric_class,
         num_outputs: int = 1,
-        compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Any = None,
-        dist_sync_fn: Optional[Callable] = None,
-        **base_metric_kwargs,
+        **kwargs,
     ) -> None:
-        super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-            dist_sync_fn=dist_sync_fn,
-        )
+        super().__init__(**kwargs)
         self.metric = MultioutputWrapper(
-            base_metric_class(
-                compute_on_step=compute_on_step,
-                dist_sync_on_step=dist_sync_on_step,
-                process_group=process_group,
-                dist_sync_fn=dist_sync_fn,
-                **base_metric_kwargs,
-            ),
+            base_metric_class(**kwargs),
             num_outputs=num_outputs,
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            dist_sync_fn=dist_sync_fn,
         )

     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
115 changes: 33 additions & 82 deletions torchmetrics/aggregation.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import warnings
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch import Tensor
@@ -39,14 +39,8 @@ class BaseAggregator(Metric):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -63,16 +57,9 @@ def __init__(
         default_value: Union[Tensor, List],
         nan_strategy: Union[str, float] = "error",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
     ):
-        super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-            dist_sync_fn=dist_sync_fn,
-        )
+        super().__init__(compute_on_step=compute_on_step, **kwargs)
         allowed_nan_strategy = ("error", "warn", "ignore")
         if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
             raise ValueError(
@@ -128,14 +115,8 @@ class MaxMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -154,18 +135,14 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
     ):
         super().__init__(
             "max",
             -torch.tensor(float("inf")),
             nan_strategy,
             compute_on_step,
-            dist_sync_on_step,
-            process_group,
-            dist_sync_fn,
+            **kwargs,
         )

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
@@ -196,14 +173,8 @@ class MinMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -222,18 +193,14 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
     ):
         super().__init__(
             "min",
             torch.tensor(float("inf")),
             nan_strategy,
             compute_on_step,
-            dist_sync_on_step,
-            process_group,
-            dist_sync_fn,
+            **kwargs,
         )

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
@@ -264,14 +231,8 @@ class SumMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -290,12 +251,14 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
     ):
         super().__init__(
-            "sum", torch.tensor(0.0), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn
+            "sum",
+            torch.tensor(0.0),
+            nan_strategy,
+            compute_on_step,
+            **kwargs,
         )

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
@@ -325,14 +288,8 @@ class CatMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -351,11 +308,9 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
     ):
-        super().__init__("cat", [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
+        super().__init__("cat", [], nan_strategy, compute_on_step, **kwargs)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
"""Update state with data.
@@ -391,14 +346,8 @@ class MeanMetric(BaseAggregator):
             .. deprecated:: v0.8
                 Argument has no use anymore and will be removed v0.9.
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step.
-        process_group:
-            Specify the process group on which synchronization is called.
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state.
-            When `None`, DDP will be used to perform the allgather.
+        kwargs:
+            Additional keyword arguments, see :ref:`Metric kwargs` for more info.
     Raises:
         ValueError:
@@ -417,12 +366,14 @@ def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
         compute_on_step: Optional[bool] = None,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None,
+        **kwargs: Dict[str, Any],
    ):
         super().__init__(
-            "sum", torch.tensor(0.0), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn
+            "sum",
+            torch.tensor(0.0),
+            nan_strategy,
+            compute_on_step,
+            **kwargs,
         )
         self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")
