Commit f293b9b

bobakfb authored and facebook-github-bot committed
Fix warning in aggregation.mean (#187)
Summary:
Pull Request resolved: #187

This diff fixes the incorrect warning emitted by `mean.compute()` when the mean is exactly 0. Instead of checking whether the weighted sum of elements is 0, we now check whether the total sum of weights is 0, so an average of 0 is returned without complaint, but we still warn before dividing by zero.

We also update the warning message to reflect that the real issue is that no weight has been accumulated, since it is possible to call `update()` with only 0 weights.

Addresses: #185

Reviewed By: JKSenthil

Differential Revision: D50806243

fbshipit-source-id: 04d75826ae8c1a24cc3718967d86bdd982081538
1 parent 235aa26 commit f293b9b
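
The behavior change can be sketched with the public `torcheval.metrics.Mean` API exercised by the test below; the expected results in the comments follow from the new `compute()` logic and are illustrative, not captured program output:

import torch
from torcheval.metrics import Mean

# Values that legitimately average to exactly 0: before this fix the falsy
# weighted sum tripped the "no calls to update()" warning.
metric = Mean()
metric.update(torch.tensor([1.0, -1.0]))
metric.compute()  # tensor(0., dtype=torch.float64), no warning after this fix

# Zero total weight: the division would be 0/0, so compute() now warns
# about the missing weight and returns 0.0 instead.
metric = Mean()
metric.update(torch.tensor([0.0, 0.0]), weight=0)
metric.compute()  # warns, then returns tensor(0., dtype=torch.float64)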

File tree

2 files changed: +12 -4 lines changed


tests/metrics/aggregation/test_mean.py

+4

@@ -54,6 +54,10 @@ def test_mean_class_compute_without_update(self) -> None:
         metric = Mean()
         self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64))
 
+        metric = Mean()
+        metric.update(torch.tensor([0.0, 0.0]), weight=0)
+        self.assertEqual(metric.compute(), torch.tensor(0.0, dtype=torch.float64))
+
     def test_mean_class_update_input_valid_weight(self) -> None:
         update_value = [
             torch.rand(BATCH_SIZE),

torcheval/metrics/aggregation/mean.py

+8 -4

@@ -55,9 +55,11 @@ def __init__(
         device: Optional[torch.device] = None,
     ) -> None:
         super().__init__(device=device)
+        # weighted sum of values over the entire state
         self._add_state(
             "weighted_sum", torch.tensor(0.0, device=self.device, dtype=torch.float64)
         )
+        # sum total of weights over the entire state
         self._add_state(
             "weights", torch.tensor(0.0, device=self.device, dtype=torch.float64)
         )
@@ -82,9 +84,9 @@ def update(
             ValueError: If value of weight is neither a ``float`` nor a ``int'' nor a ``torch.Tensor`` that matches the input tensor size.
         """
 
-        weighted_sum, weights = _mean_update(input, weight)
+        weighted_sum, net_weight = _mean_update(input, weight)
         self.weighted_sum += weighted_sum
-        self.weights += weights
+        self.weights += net_weight
         return self
 
     @torch.inference_mode()
@@ -93,8 +95,10 @@ def compute(self: TMean) -> torch.Tensor:
         If no calls to ``update()`` are made before ``compute()`` is called,
         the function throws a warning and returns 0.0.
         """
-        if not self.weighted_sum:
-            logging.warning("No calls to update() have been made - returning 0.0")
+        if not torch.is_nonzero(self.weights):
+            logging.warning(
+                "There is no weight for the average, no samples with weight have been added (did you ever run update()?)- returning 0.0"
+            )
             return torch.tensor(0.0, dtype=torch.float64)
         return self.weighted_sum / self.weights
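
For reference, the new guard relies on `torch.is_nonzero`, which evaluates a single-element tensor as a Python bool. A rough standalone sketch of the check, with a local `weights` tensor standing in for the metric's accumulated state:

import torch

weights = torch.tensor(0.0, dtype=torch.float64)
torch.is_nonzero(weights)   # False -> warn and return 0.0
weights += 2.5              # e.g. after an update() with positive weight
torch.is_nonzero(weights)   # True -> safe to return weighted_sum / weights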
