Skip to content

Commit

Permalink
Disable batch_size extraction for torchmetric instances (#10815)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
rohitgr7 and carmocca authored Nov 30, 2021
1 parent 32e6522 commit 1437be5
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 43 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))


- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
- Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))


- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))


-
Expand Down
16 changes: 9 additions & 7 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,13 @@ class ResultMetricCollection(dict):
with the same metadata.
"""

def __init__(self, *args: Any) -> None:
super().__init__(*args)

@property
def meta(self) -> _Metadata:
return list(self.values())[0].meta
return next(iter(self.values())).meta

@property
def has_tensor(self) -> bool:
return any(v.is_tensor for v in self.values())

def __getstate__(self, drop_value: bool = False) -> dict:
def getstate(item: ResultMetric) -> dict:
Expand Down Expand Up @@ -403,7 +404,7 @@ def append_fn(v: ResultMetric) -> None:
apply_to_collection(list(self.values()), ResultMetric, append_fn)
return o

def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int:
def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
# check if we have extracted the batch size already
if batch_size is None:
batch_size = self.batch_size
Expand All @@ -412,7 +413,8 @@ def _extract_batch_size(self, batch_size: Optional[int], meta: _Metadata) -> int
return batch_size

batch_size = 1
if self.batch is not None and meta.on_epoch and meta.is_mean_reduction:
is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size

Expand Down Expand Up @@ -477,7 +479,7 @@ def log(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)

batch_size = self._extract_batch_size(batch_size, meta)
batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
Expand Down
50 changes: 49 additions & 1 deletion tests/trainer/logging_/test_logger_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import torch
from torch.utils.data import DataLoader
from torchmetrics import Accuracy, AveragePrecision
from torchmetrics import Accuracy, AveragePrecision, MeanAbsoluteError, MeanSquaredError

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.base import Callback
Expand Down Expand Up @@ -640,3 +640,51 @@ def training_step(self, batch, batch_idx):

# should not get overridden if logged manually
assert trainer.logged_metrics == {"epoch": -1}


def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)

results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
results.log(fx_name, "train_logs", {"mse": train_mse, "log_val": log_val}, on_step=False, on_epoch=True)
assert results.batch_size == 1
assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
assert results["training_step.train_logs"]["log_val"].value == log_val

results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
assert results["training_step.train_log"].value == log_val
assert results["training_step.train_log"].cumulated_batch_size == 1


def test_result_collection_no_batch_size_extraction():
results = ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10
log_val = torch.tensor(7.0)

train_mae = MeanAbsoluteError()
train_mae(torch.randn(4, 5), torch.randn(4, 5))
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False)
results.log(fx_name, "epoch_log_val", log_val, on_step=False, on_epoch=True, batch_size=batch_size)
results.log(fx_name, "epoch_sum_log_val", log_val, on_step=True, on_epoch=True, reduce_fx="sum")
results.log(fx_name, "train_mae", train_mae, on_step=True, on_epoch=False)
results.log(fx_name, "train_mse", {"mse": train_mse}, on_step=True, on_epoch=False)

assert results.batch_size is None
assert isinstance(results["training_step.train_mse"]["mse"].value, MeanSquaredError)
assert isinstance(results["training_step.train_mae"].value, MeanAbsoluteError)
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
assert results["training_step.epoch_sum_log_val"].value == log_val
34 changes: 0 additions & 34 deletions tests/trainer/logging_/test_train_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -746,36 +745,3 @@ def validation_epoch_end(self, *_) -> None:
train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


def test_no_batch_size_extraction_with_specifying_explictly(tmpdir):
batch_size = BoringModel().train_dataloader().batch_size + 1
fast_dev_run = 2
log_val = 7

class CustomBoringModel(BoringModel):
def on_before_batch_transfer(self, batch, *args, **kwargs):
# This is an ambiguous batch which have multiple potential batch sizes
if self.trainer.training:
batch = {"batch1": torch.randn(batch_size, 10), "batch2": batch}
return batch

def training_step(self, batch, batch_idx):
self.log("step_log_val", log_val, on_epoch=False)
self.log("epoch_log_val", log_val, batch_size=batch_size, on_step=False, on_epoch=True)
self.log("epoch_sum_log_val", log_val, on_epoch=True, reduce_fx="sum")
return super().training_step(batch["batch2"], batch_idx)

def on_train_epoch_end(self, *args, **kwargs):
results = self.trainer._results
assert results["training_step.step_log_val"].value == log_val
assert results["training_step.step_log_val"].cumulated_batch_size == 0
assert results["training_step.epoch_log_val"].value == log_val * batch_size * fast_dev_run
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size * fast_dev_run
assert results["training_step.epoch_sum_log_val"].value == log_val * fast_dev_run

model = CustomBoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=fast_dev_run)

with no_warning_call(match="Trying to infer the `batch_size`"):
trainer.fit(model)

0 comments on commit 1437be5

Please sign in to comment.