Skip to content

Commit

Permalink
[bug-fix] Metric reduction with Logging (#5150)
Browse files Browse the repository at this point in the history
* add test

* resolve bug

* udpate test

* wrongly copy / paste

* update test

* resolve a second bug

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
tchaton and Ubuntu committed Dec 21, 2020
1 parent 0f36525 commit a48ca18
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 7 deletions.
7 changes: 5 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()
Expand Down
10 changes: 8 additions & 2 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""

import numbers
import os
import re
from copy import deepcopy
Expand All @@ -32,6 +33,8 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -574,8 +577,11 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
epoch = metrics.get("epoch")
step = metrics.get("step")

if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)
if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['forked']:
result[dl_key] = self[k]
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
else:
result[dl_key] = self[k]

return result

Expand Down
49 changes: 47 additions & 2 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from torch.utils.data import Dataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer, callbacks
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.deterministic_model import DeterministicModel
Expand Down Expand Up @@ -817,3 +817,48 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
'on_epoch_end': 5,
'on_train_epoch_end': 6}
assert trainer.callback_metrics == expected


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_metric_are_properly_reduced(tmpdir):
class TestingModel(BoringModel):
def __init__(self, *args, **kwargs):
super().__init__()
self.train_acc = pl.metrics.Accuracy()
self.val_acc = pl.metrics.Accuracy()

def training_step(self, batch, batch_idx):
self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
preds = torch.tensor(0, device=self.device)
targets = torch.tensor(1, device=self.device)
if batch_idx < 8:
targets = preds
self.val_acc(preds, targets)
self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
return super().validation_step(batch, batch_idx)

early_stop = EarlyStopping(monitor='val_acc', mode='max')

checkpoint = ModelCheckpoint(
monitor='val_acc',
save_last=True,
save_top_k=2,
mode='max',
)

model = TestingModel()
trainer = Trainer(
default_root_dir=tmpdir,
gpus=1,
max_epochs=2,
limit_train_batches=5,
limit_val_batches=32,
callbacks=[early_stop, checkpoint])
trainer.fit(model)

assert trainer.callback_metrics["val_acc"] == 8 / 32.
assert "train_acc" in trainer.callback_metrics

0 comments on commit a48ca18

Please sign in to comment.