
Commit

Fix min/max logging default value (#11310)
Co-authored-by: Carlos Mocholi <[email protected]>
2 people authored and lexierule committed Jan 5, 2022
1 parent 4398db2 commit b707c67
Showing 3 changed files with 19 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed

 - Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
+- Fixed the default value used with `log(reduce_fx=min|max)` ([#11310](https://github.com/PyTorchLightning/pytorch-lightning/pull/11310))
 - Fixed data fetcher selection ([#11294](https://github.com/PyTorchLightning/pytorch-lightning/pull/11294))
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
 - Fixed dataloaders not getting reloaded the correct amount of times when setting `reload_dataloaders_every_n_epochs` and `check_val_every_n_epoch` ([#10948](https://github.com/PyTorchLightning/pytorch-lightning/pull/10948))
@@ -208,8 +208,14 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
         self.meta = metadata
         self.has_reset = False
         if is_tensor:
+            if metadata.is_max_reduction:
+                default = float("-inf")
+            elif metadata.is_min_reduction:
+                default = float("inf")
+            else:
+                default = 0.0
             # do not set a dtype in case the default dtype was changed
-            self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum)
+            self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)

13 changes: 11 additions & 2 deletions tests/core/test_metric_result_integration.py
@@ -575,14 +575,14 @@ def test_metric_result_respects_dtype(floating_dtype):
     assert rm.cumulated_batch_size.dtype == fixed_dtype

     # two fixed point numbers - should be converted
-    value, batch_size = torch.tensor(2), torch.tensor(3)
+    value, batch_size = torch.tensor(2), 3
     assert value.dtype == fixed_dtype
     with pytest.warns(
         UserWarning, match=rf"`self.log\('bar', ...\)` in your `foo` .* Converting it to {floating_dtype}"
     ):
         rm.update(value, batch_size)
     # floating and fixed
-    rm.update(torch.tensor(4.0), torch.tensor(5))
+    rm.update(torch.tensor(4.0), 5)

     total = rm.compute()

@@ -591,3 +591,12 @@ def test_metric_result_respects_dtype(floating_dtype):

     # restore to avoid impacting other tests
     torch.set_default_dtype(torch.float)
+
+
+@pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
+def test_result_metric_max_min(reduce_fx, expected):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = ResultMetric(metadata, is_tensor=True)
+    rm.update(torch.tensor(expected), 1)
+    assert rm.compute() == expected
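For context, the user-facing path that exercises this default is `self.log` with a min or max reduction. A hypothetical module illustrating the behaviour the commit fixes (illustrative sketch only, not part of the diff; the module, data, and metric name are made up):

import torch
import pytorch_lightning as pl


class HypotheticalModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        # An always-negative quantity: with the old 0.0 default, its epoch-level
        # max would have been reported as 0.0; it now aggregates from -inf.
        loss = -torch.nn.functional.softplus(self.layer(batch)).mean()
        self.log("neg_loss_max", loss, reduce_fx=torch.max, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)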
