Skip to content

Commit

Permalink
Fix list synchronization with partly empty lists (#2468)
Browse files Browse the repository at this point in the history
* implementation + tests

* changelog

* fix tests

* only for newer versions

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored Apr 22, 2024
1 parent 9d04667 commit cd7ccfc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed list synchronization with partly empty lists ([#2468](https://github.com/Lightning-AI/torchmetrics/pull/2468))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


Expand Down
10 changes: 10 additions & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
from torchmetrics.utilities.prints import rank_zero_warn

Expand Down Expand Up @@ -438,6 +439,15 @@ def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group:
if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:
input_dict[attr] = [dim_zero_cat(input_dict[attr])]

# corner case in distributed settings where a rank has not received any data: create an empty tensor to concatenate
if (
_TORCH_GREATER_EQUAL_2_1
and reduction_fn == dim_zero_cat
and isinstance(input_dict[attr], list)
and len(input_dict[attr]) == 0
):
input_dict[attr] = [torch.tensor([], device=self.device, dtype=self.dtype)]

output_dict = apply_to_collection(
input_dict,
Tensor,
Expand Down
20 changes: 19 additions & 1 deletion tests/unittests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors
from torchmetrics.utilities.exceptions import TorchMetricsUserError
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1

from unittests import NUM_PROCESSES
from unittests._helpers import seed_all
Expand Down Expand Up @@ -269,11 +270,28 @@ def test_sync_on_compute(sync_on_compute, test_func):
def _test_sync_with_empty_lists(rank):
dummy = DummyListMetric()
val = dummy.compute()
assert val == []
assert torch.allclose(val, tensor([]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_empty_lists():
    """Run ``_test_sync_with_empty_lists`` once per rank index through ``pytest.pool.map``."""
    # each worker call receives one value from range(NUM_PROCESSES) as its ``rank`` argument
    pytest.pool.map(_test_sync_with_empty_lists, range(NUM_PROCESSES))


def _test_sync_with_unequal_size_lists(rank):
    """Test that synchronization of list states works even when some ranks have not received any data yet."""
    dummy = DummyListMetric()
    # only rank 0 is updated; every other rank calls compute() without having received data
    if rank == 0:
        dummy.update(torch.zeros(2))
    # every rank, updated or not, is expected to get the two zeros from compute()
    assert torch.all(dummy.compute() == tensor([0.0, 0.0]))


@pytest.mark.DDP()
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_1, reason="test only works on newer torch versions")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_with_unequal_size_lists():
    """Run ``_test_sync_with_unequal_size_lists`` once per rank index through ``pytest.pool.map``."""
    # each worker call receives one value from range(NUM_PROCESSES) as its ``rank`` argument
    pytest.pool.map(_test_sync_with_unequal_size_lists, range(NUM_PROCESSES))

0 comments on commit cd7ccfc

Please sign in to comment.