Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/move args to kwargs #833

Merged
merged 43 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
35fe10c
move kwargs
SkafteNicki Feb 9, 2022
4ba7728
update
SkafteNicki Feb 10, 2022
d8ac64d
link
SkafteNicki Feb 10, 2022
864e5f7
revert
SkafteNicki Feb 10, 2022
1ef093a
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 10, 2022
7086cb7
Apply suggestions from code review
Borda Feb 10, 2022
f0024fc
move static args
SkafteNicki Feb 11, 2022
1688868
typing
SkafteNicki Feb 11, 2022
065cde1
Update docs/source/pages/overview.rst
SkafteNicki Feb 11, 2022
d750ab3
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 11, 2022
9d318eb
include link in base
SkafteNicki Feb 11, 2022
1cd056a
audio
SkafteNicki Feb 16, 2022
4d216e4
wrappers
SkafteNicki Feb 16, 2022
c6598a7
detection
SkafteNicki Feb 16, 2022
ae1c5f7
image
SkafteNicki Feb 16, 2022
972dda7
retrieval
SkafteNicki Feb 16, 2022
d7fa5ae
regression
SkafteNicki Feb 16, 2022
659bb39
text
SkafteNicki Feb 16, 2022
4f74ca1
classification
SkafteNicki Feb 16, 2022
00ebd7b
ref
Borda Feb 16, 2022
6d6f9a7
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 16, 2022
e739dfe
reference
SkafteNicki Feb 16, 2022
81f99ca
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 16, 2022
bed642d
changelog
SkafteNicki Feb 16, 2022
4f211e5
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 16, 2022
803f11d
fix mypy
SkafteNicki Feb 16, 2022
53f9d13
fix inconsistency in args
SkafteNicki Feb 16, 2022
96a9fa0
remove restriction on process group
SkafteNicki Feb 16, 2022
b248cbb
rename to kwargs
SkafteNicki Feb 16, 2022
b9336ad
fix tests
SkafteNicki Feb 17, 2022
c5737e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
d14294c
add base testing
SkafteNicki Feb 17, 2022
abee6e4
fix remaining examples
SkafteNicki Feb 17, 2022
7a67fb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
ad20122
flake8
SkafteNicki Feb 17, 2022
75508e1
mypy
SkafteNicki Feb 17, 2022
8d3d807
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
147e279
Merge branch 'refactor/move_args_to_kwargs' of https://github.com/PyT…
SkafteNicki Feb 17, 2022
2c7b08d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
1a5f5f0
mypy
SkafteNicki Feb 17, 2022
9a184a2
fix tests
SkafteNicki Feb 17, 2022
fec6586
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 18, 2022
7d4388e
Merge branch 'master' into refactor/move_args_to_kwargs
SkafteNicki Feb 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `ClasswiseWrapper` for better logging of classification metrics with multiple output values ([#832](https://github.com/PyTorchLightning/metrics/pull/832))


- Added `**kwargs` argument for passing additional arguments to base class ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


### Changed


### Deprecated

- Deprecated method `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))
- Deprecated argument `compute_on_step` ([#792](https://github.com/PyTorchLightning/metrics/pull/792))


- Deprecated passing in `dist_sync_on_step`, `process_group`, `dist_sync_fn` direct argument ([#833](https://github.com/PyTorchLightning/metrics/pull/833))


### Removed
Expand Down
21 changes: 21 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,24 @@ In practise this means that:
val = metric.compute() # this value cannot be back-propagated

A functional metric is differentiable if its corresponding modular metric is differentiable.

.. _Metric kwargs:

*****************************
Advanced distributed settings
*****************************

If you are running in a distributed environment, ``TorchMetrics`` will automatically take care of the distributed
synchronization for you. However, the following three keyword arguments can be given to any metric class for
further control over the distributed aggregation:

- ``dist_sync_on_step``: This argument is ``bool`` that indicates if the metric should syncronize between
different devices every time ``forward`` is called. Setting this to ``True`` is in general not recommended
as syncronization is an expensive operation to do after each batch.

- ``process_group``: By default we syncronize across the *world* i.e. all proceses being computed on. You
can provide an ``torch._C._distributed_c10d.ProcessGroup`` in this argument to specify exactly what
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
devices should be syncronized over.

- ``dist_sync_fn``: By default we use :func:`torch.distributed.all_gather` to perform the synchronization between
devices. Provide another callable function for this argument to perform custom distributed synchronization.
11 changes: 11 additions & 0 deletions tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,22 @@
seed_all(42)


def test_error_on_wrong_input():
"""Test that base metric class raises error on wrong input types."""
with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_on_step` to be an `bool` but.*"):
DummyMetric(dist_sync_on_step=None)

with pytest.raises(ValueError, match="Expected keyword argument `dist_sync_fn` to be an callable function.*"):
DummyMetric(dist_sync_fn=[2, 3])


def test_inherit():
"""Test that metric that inherits can be instanciated."""
DummyMetric()


def test_add_state():
"""Test that add state method works as expected."""
a = DummyMetric()

a.add_state("a", tensor(0), "sum")
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,8 +563,8 @@ def run_differentiability_test(
class DummyMetric(Metric):
name = "Dummy"

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)

def update(self):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from torchmetrics import Accuracy


def test_compute_on_step():
with pytest.warns(
DeprecationWarning, match="Argument `compute_on_step` is deprecated in v0.8 and will be removed in v0.9"
):
Accuracy(compute_on_step=False) # any metric will raise the warning
25 changes: 3 additions & 22 deletions tests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import namedtuple
from functools import partial
from typing import Any, Callable, Optional

import pytest
import torch
Expand All @@ -24,30 +23,12 @@ def __init__(
self,
base_metric_class,
num_outputs: int = 1,
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Any = None,
dist_sync_fn: Optional[Callable] = None,
**base_metric_kwargs,
**kwargs,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(**kwargs)
self.metric = MultioutputWrapper(
base_metric_class(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
**base_metric_kwargs,
),
base_metric_class(**kwargs),
num_outputs=num_outputs,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
dist_sync_fn=dist_sync_fn,
)

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
Expand Down
115 changes: 33 additions & 82 deletions torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import Tensor
Expand All @@ -39,14 +39,8 @@ class BaseAggregator(Metric):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -63,16 +57,9 @@ def __init__(
default_value: Union[Tensor, List],
nan_strategy: Union[str, float] = "error",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
super().__init__(compute_on_step=compute_on_step, **kwargs)
allowed_nan_strategy = ("error", "warn", "ignore")
if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
raise ValueError(
Expand Down Expand Up @@ -128,14 +115,8 @@ class MaxMetric(BaseAggregator):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -154,18 +135,14 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"max",
-torch.tensor(float("inf")),
nan_strategy,
compute_on_step,
dist_sync_on_step,
process_group,
dist_sync_fn,
**kwargs,
)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
Expand Down Expand Up @@ -196,14 +173,8 @@ class MinMetric(BaseAggregator):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -222,18 +193,14 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"min",
torch.tensor(float("inf")),
nan_strategy,
compute_on_step,
dist_sync_on_step,
process_group,
dist_sync_fn,
**kwargs,
)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
Expand Down Expand Up @@ -264,14 +231,8 @@ class SumMetric(BaseAggregator):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -290,12 +251,14 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"sum", torch.tensor(0.0), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn
"sum",
torch.tensor(0.0),
nan_strategy,
compute_on_step,
**kwargs,
)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
Expand Down Expand Up @@ -325,14 +288,8 @@ class CatMetric(BaseAggregator):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -351,11 +308,9 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__("cat", [], nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
super().__init__("cat", [], nan_strategy, compute_on_step, **kwargs)

def update(self, value: Union[float, Tensor]) -> None: # type: ignore
"""Update state with data.
Expand Down Expand Up @@ -391,14 +346,8 @@ class MeanMetric(BaseAggregator):
.. deprecated:: v0.8
Argument has no use anymore and will be removed v0.9.

dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step.
process_group:
Specify the process group on which synchronization is called.
dist_sync_fn:
Callback that performs the allgather operation on the metric state.
When `None`, DDP will be used to perform the allgather.
kwargs:
Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Raises:
ValueError:
Expand All @@ -417,12 +366,14 @@ def __init__(
self,
nan_strategy: Union[str, float] = "warn",
compute_on_step: Optional[bool] = None,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
**kwargs: Dict[str, Any],
):
super().__init__(
"sum", torch.tensor(0.0), nan_strategy, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn
"sum",
torch.tensor(0.0),
nan_strategy,
compute_on_step,
**kwargs,
)
self.add_state("weight", default=torch.tensor(0.0), dist_reduce_fx="sum")

Expand Down
Loading