
Commit 67c6709

Fix sync_all_reduce to consider update->compute->update case (#2803)
* Fix sync_all_reduce * Make _is_reduced no-op, add a test and a little improvement * Make _is_reduced no-op, add a test and a little improvement * Fix Mypy * Remove some asserts from test_loss & test_accuracy * Fix bug in tests, causing error when metric_device.type == XLA * Remove _is_reduced From now on, compute after compute does the whole thing again * Revert deleted assertions * Fix a bug in precision * Fix a mypy error * Update ignite/metrics/metric.py Co-authored-by: vfdev <[email protected]> * Revert a change in test_accuracy * Revert a change in test_accuracy exactly --------- Co-authored-by: vfdev <[email protected]>
1 parent 5d8d6bf commit 67c6709

File tree: 5 files changed, +100 −41 lines

* ignite/metrics/metric.py (+32 −22)
* ignite/metrics/precision.py (+5 −9)
* tests/ignite/metrics/test_accuracy.py (+13 −7)
* tests/ignite/metrics/test_loss.py (+2 −2)
* tests/ignite/metrics/test_metric.py (+48 −1)
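Background for the diffs below: before this commit, sync_all_reduce overwrote a metric's accumulators with their all-reduced values and then set an _is_reduced flag that skipped any further reduction, so an update -> compute -> update -> compute sequence accumulated on top of already-reduced state and skipped the second reduction. A minimal sketch of the failure mode, assuming two workers (the CountExamples metric below is illustrative, not part of the commit):

import ignite.distributed as idist
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


class CountExamples(Metric):
    # Illustrative metric: counts update() calls on each worker.
    @reinit__is_reduced
    def reset(self):
        self._n = 0

    def update(self, output):
        self._n += 1

    @sync_all_reduce("_n")
    def compute(self):
        return self._n


# With 2 workers, each running: update(); compute(); update(); compute()
#   fixed behavior: compute() returns 2, then 4 -- the decorator reduces a
#   copy and restores the local `_n` (1, then 2) after each compute().
#   old behavior: the first compute() stored the reduced value (_n = 2) and
#   set `_is_reduced`, so the second compute() skipped reduction and
#   returned the local 3 instead of the correct 4.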

Diff for: ignite/metrics/metric.py

+32 −22

@@ -2,7 +2,7 @@
 from collections.abc import Mapping
 from functools import wraps
 from numbers import Number
-from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
 
@@ -214,7 +214,6 @@ def __init__(
             raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")
 
         self._device = torch.device(device)
-        self._is_reduced = False
         self.reset()
 
     @abstractmethod
@@ -556,25 +555,37 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
                     "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                 )
             ws = idist.get_world_size()
-            if len(attrs) > 0 and not self._is_reduced:
-                if ws > 1:
-                    for attr in attrs:
-                        op_kwargs = {}
-                        if ":" in attr:
-                            attr, op = attr.split(":")
-                            valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
-                            if op not in valid_ops:
-                                raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
-                            op_kwargs["op"] = op
-                        t = getattr(self, attr, None)
-                        if t is not None:
-                            t = idist.all_reduce(t, **op_kwargs)
-                            self._is_reduced = True
-                            setattr(self, attr, t)
-                else:
-                    self._is_reduced = True
-
-            return func(self, *args, **kwargs)
+            unreduced_attrs = {}
+            if len(attrs) > 0 and ws > 1:
+                for attr in attrs:
+                    op_kwargs = {}
+                    if ":" in attr:
+                        attr, op = attr.split(":")
+                        valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
+                        if op not in valid_ops:
+                            raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op}")
+                        op_kwargs["op"] = op
+                    if attr not in self.__dict__:
+                        raise ValueError(f"Metric {type(self)} has no attribute named `{attr}`.")
+                    t = getattr(self, attr)
+                    if not isinstance(t, (Number, torch.Tensor)):
+                        raise TypeError(
+                            "Attribute provided to sync_all_reduce should be a "
+                            f"number or tensor but `{attr}` has type {type(t)}"
+                        )
+                    unreduced_attrs[attr] = t
+                    # Here `clone` is necessary since `idist.all_reduce` modifies `t` inplace in the case
+                    # `t` is a tensor and its `device` is same as that of the process.
+                    # TODO: Remove this dual behavior of `all_reduce` to always either return a new tensor or
+                    # modify it in-place.
+                    t_reduced = idist.all_reduce(cast(float, t) if isinstance(t, Number) else t.clone(), **op_kwargs)
+                    setattr(self, attr, t_reduced)
+
+            result = func(self, *args, **kwargs)
+
+            for attr, value in unreduced_attrs.items():
+                setattr(self, attr, value)
+            return result
 
         return another_wrapper
 
@@ -594,7 +605,6 @@ def reinit__is_reduced(func: Callable) -> Callable:
     @wraps(func)
     def wrapper(self: Metric, *args: Any, **kwargs: Any) -> None:
        func(self, *args, **kwargs)
-       self._is_reduced = False
        if "_result" in self.__dict__:
            self._result = None  # type: ignore[attr-defined]
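For reference, a hedged sketch of the decorator's behavior after this change (the MinMaxTracker metric below is illustrative, not part of the commit): an attribute may carry an "attr:OP" suffix selecting the reduction op (MIN, MAX, SUM, or PRODUCT; idist.all_reduce defaults to SUM); when the world size is greater than one, each listed attribute must already exist on the instance (ValueError otherwise) and be a number or tensor (TypeError otherwise), and after the wrapped compute returns, the unreduced local values are restored:

import torch

from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce


class MinMaxTracker(Metric):
    # Illustrative metric: tracks the global min/max over batches of values.
    @reinit__is_reduced
    def reset(self):
        self._min = torch.tensor(float("inf"), device=self._device)
        self._max = torch.tensor(float("-inf"), device=self._device)

    def update(self, output):
        self._min = torch.minimum(self._min, output.min())
        self._max = torch.maximum(self._max, output.max())

    # "attr:OP" picks the reduction per attribute; the decorator reduces
    # copies, calls compute(), then restores the local (unreduced) values.
    @sync_all_reduce("_min:MIN", "_max:MAX")
    def compute(self):
        return self._min.item(), self._max.item()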

Diff for: ignite/metrics/precision.py

+5 −9

@@ -6,7 +6,7 @@
 import ignite.distributed as idist
 from ignite.exceptions import NotComputableError
 from ignite.metrics.accuracy import _BaseClassification
-from ignite.metrics.metric import reinit__is_reduced
+from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
 from ignite.utils import to_onehot
 
 __all__ = ["Precision"]
@@ -121,6 +121,7 @@ def reset(self) -> None:
 
         super(_BasePrecisionRecall, self).reset()
 
+    @sync_all_reduce("_numerator", "_denominator")
     def compute(self) -> Union[torch.Tensor, float]:
 
         # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
@@ -138,18 +139,13 @@ def compute(self) -> Union[torch.Tensor, float]:
             raise NotComputableError(
                 f"{self.__class__.__name__} must have at least one example before it can be computed."
             )
-        if not self._is_reduced:
-            self._numerator = idist.all_reduce(self._numerator)  # type: ignore[assignment]
-            self._denominator = idist.all_reduce(self._denominator)  # type: ignore[assignment]
-            if self._average == "weighted":
-                self._weight = idist.all_reduce(self._weight)  # type: ignore[assignment]
-            self._is_reduced: bool = True
 
         fraction = self._numerator / (self._denominator + (self.eps if self._average != "samples" else 0))
 
         if self._average == "weighted":
-            sum_of_weights = cast(torch.Tensor, self._weight).sum() + self.eps
-            return ((fraction @ self._weight) / sum_of_weights).item()  # type: ignore
+            _weight = idist.all_reduce(self._weight.clone())  # type: ignore[union-attr]
+            sum_of_weights = cast(torch.Tensor, _weight).sum() + self.eps
+            return ((fraction @ _weight) / sum_of_weights).item()  # type: ignore
         elif self._average == "micro" or self._average == "samples":
             return cast(torch.Tensor, fraction).item()
         elif self._average == "macro":
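The commit's own comment in metric.py explains the .clone() above: idist.all_reduce modifies a tensor in place when it lives on the process's own device, so reducing a clone leaves the per-process _weight accumulator untouched for later update() calls. A minimal illustration of the pattern (values shown are single-process; the reduction is a no-op at world size 1):

import torch

import ignite.distributed as idist

local_weight = torch.tensor([1.0, 2.0])           # per-process accumulator
reduced = idist.all_reduce(local_weight.clone())  # summed across workers
# `local_weight` still holds only this worker's values here, so a later
# update() keeps accumulating unreduced local statistics.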

Diff for: tests/ignite/metrics/test_accuracy.py

+13 −7

@@ -275,16 +275,18 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
 
         np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
 
@@ -298,6 +300,9 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -306,14 +311,13 @@ def _test(metric_device):
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
 
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
         # check that result is not changed
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
 
@@ -334,6 +338,9 @@ def _test(metric_device):
             acc._num_correct.device == metric_device
         ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
+        n = acc._num_examples
+        assert n == y.numel() / y.size(dim=1)
+
         # gather y_pred, y
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -342,9 +349,8 @@ def _test(metric_device):
         np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L ..., C)
 
         assert acc._type == "multilabel"
-        n = acc._num_examples
         res = acc.compute()
-        assert n * idist.get_world_size() == acc._num_examples
+        assert n == acc._num_examples
         assert isinstance(res, float)
         assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

Diff for: tests/ignite/metrics/test_loss.py

+2 −2

@@ -148,7 +148,7 @@ def _test(metric_device, y_test_1, y_test_2):
         n = loss._num_examples
         assert n == len(y)
         res = loss.compute()
-        assert n * idist.get_world_size() == loss._num_examples
+        assert n == loss._num_examples
 
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
@@ -160,7 +160,7 @@ def _test(metric_device, y_test_1, y_test_2):
         loss.update((y_pred, y))
         n = loss._num_examples
         res = loss.compute()
-        assert n * idist.get_world_size() == loss._num_examples
+        assert n == loss._num_examples
 
         y_pred = idist.all_gather(y_pred)
         y = idist.all_gather(y)
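The assertion changes in test_accuracy.py and test_loss.py above follow directly from the decorator fix: compute() no longer leaves _num_examples multiplied by the world size, because the pre-reduction local count is restored after the all-reduce. A single-process sketch of the invariant (the reduction is a no-op at world size 1):

import torch

from ignite.metrics import Accuracy

acc = Accuracy()
acc.update((torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0])))
n = acc._num_examples          # local count before compute()
acc.compute()                  # all-reduces internally, then restores state
assert acc._num_examples == n  # no world-size multiplication afterwards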

Diff for: tests/ignite/metrics/test_metric.py

+48 −1

@@ -535,6 +535,29 @@ def update(self, output):
        pass
 
 
+def _test_compute_with_sync_all_reduce_doesnt_change_attributes(device):
+    class DummyMetric3(Metric):
+        @reinit__is_reduced
+        def reset(self):
+            self.a = torch.tensor(0.0, device=self._device)
+            self.b = 0.0
+
+        def update(self, output):
+            self.a += torch.tensor(1.0)
+            self.b += 1.0
+
+        @sync_all_reduce("a", "b")
+        def compute(self):
+            return self.a.item(), self.b
+
+    metric_device = device if torch.device(device).type != "xla" else "cpu"
+    metric = DummyMetric3(device=metric_device)
+    metric.update(None)
+    assert metric.a.item() == metric.b == 1.0
+    metric.compute()
+    assert metric.a.item() == metric.b == 1.0
+
+
 def _test_invalid_sync_all_reduce(device):
     class InvalidMetric(Metric):
         @reinit__is_reduced
@@ -543,6 +566,7 @@ def reset(self):
             self.c = 0.0
             self.n = 0
             self.m = -1
+            self.d = "a string"
 
         def compute(self):
             pass
@@ -566,6 +590,14 @@ def invalid_reduction_op_3(self):
         def invalid_reduction_op_4(self):
             pass
 
+        @sync_all_reduce("missingattr")
+        def invalid_reduction_op_5(self):
+            pass
+
+        @sync_all_reduce("d")
+        def invalid_reduction_op_6(self):
+            pass
+
     metric_device = device if torch.device(device).type != "xla" else "cpu"
     m = InvalidMetric(device=metric_device)
     m.reset()
@@ -583,6 +615,14 @@ def invalid_reduction_op_4(self):
     with pytest.raises(ValueError, match=r"Reduction operation is not valid"):
         m.invalid_reduction_op_4()
 
+    with pytest.raises(ValueError, match=r"has no attribute named `missingattr`."):
+        m.invalid_reduction_op_5()
+
+    with pytest.raises(
+        TypeError, match=r"Attribute provided to sync_all_reduce should be a number or tensor but `d`"
+    ):
+        m.invalid_reduction_op_6()
+
 
 def _test_distrib_sync_all_reduce_decorator(device):
     class DummyMetric(Metric):
@@ -647,7 +687,7 @@ def update(self, output):
     m = DummyMetric(device=metric_device)
     m.update(None)
     m.compute()
-    # check if can call compute multiple times without all reduce invocation
+    # check if attributes are restored to their original values after previous `compute`
     m.compute()
 
 
@@ -664,6 +704,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.distributed
@@ -673,6 +714,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.distributed
@@ -685,6 +727,7 @@ def test_distrib_hvd(gloo_hvd_executor):
 
     gloo_hvd_executor(_test_distrib_sync_all_reduce_decorator, (device,), np=nproc, do_init=True)
     gloo_hvd_executor(_test_invalid_sync_all_reduce, (device,), np=nproc, do_init=True)
+    gloo_hvd_executor(_test_compute_with_sync_all_reduce_doesnt_change_attributes, (device,), np=nproc, do_init=True)
 
 
 @pytest.mark.multinode_distributed
@@ -695,6 +738,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.multinode_distributed
@@ -705,6 +749,7 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.tpu
@@ -715,13 +760,15 @@ def test_distrib_single_device_xla():
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 def _test_distrib_xla_nprocs(index):
     device = idist.device()
     _test_distrib_sync_all_reduce_decorator(device)
     _test_creating_on_xla_fails(device)
     _test_invalid_sync_all_reduce(device)
+    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
 
 
 @pytest.mark.tpu
