Skip to content

Commit

Permalink
Fix moving keys to device in ResultCollection (#19814)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Jipa <[email protected]>
  • Loading branch information
clumsy and azzhipa authored Jul 26, 2024
1 parent 2064887 commit b19eba3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))

-

### Changed

Expand Down Expand Up @@ -44,6 +45,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))

- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))



## [2.3.0] - 2024-06-13
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,26 +400,19 @@ def log(

# register logged value if it doesn't exist
if key not in self:
self.register_key(key, meta, value)
metric = _ResultMetric(meta, isinstance(value, Tensor))
self[key] = metric

# check the stored metadata and the current one match
elif meta != self[key].meta:
raise MisconfigurationException(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)
self[key].to(value.device)

batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
"""Create one _ResultMetric object per value.
Value can be provided as a nested collection
"""
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
Expand Down
21 changes: 21 additions & 0 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchmetrics import AveragePrecision as AvgPre

from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.models.test_hooks import get_members


Expand Down Expand Up @@ -639,3 +640,23 @@ def test_result_collection_no_batch_size_extraction():
assert results["training_step.epoch_log_val"].value == log_val * batch_size
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
assert results["training_step.epoch_sum_log_val"].value == log_val


@RunIf(min_cuda_gpus=1)
def test_result_collection_changes_device():
"""Test that the keys in the ResultCollection are moved to the device together with the collection."""
results = _ResultCollection(training=True)
fx, name = "training_step", "step_log_val"
log_val = torch.tensor(7.0, device="cuda:0")

# same device as the original tensor
results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device

# moved to cpu
results.cpu()
assert results[f"{fx}.{name}"].cumulated_batch_size.device == torch.device("cpu")

# same device as the new tensor
results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device

0 comments on commit b19eba3

Please sign in to comment.