[bug-fix] Metric reduction with Logging #5150

Merged · 8 commits · Dec 16, 2020
11 changes: 8 additions & 3 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,14 +19,16 @@
Monitor a metric and stop training when it stops improving.

"""
+import numbers
import os

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
+from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_info, rank_zero_warn


class EarlyStopping(Callback):
@@ -201,8 +203,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(self, current)

-        if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current, device=pl_module.device)
+        if current is not None:
+            if isinstance(current, Metric):
+                current = current.compute()
+            elif isinstance(current, numbers.Number):
+                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

        if trainer.use_tpu and TPU_AVAILABLE:
            current = current.cpu()
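The same normalization appears in ModelCheckpoint below. A minimal sketch of the shared logic, with _normalize_monitor_value as a hypothetical helper name (the PR inlines this branch in each callback rather than factoring it out):

import numbers

import torch

from pytorch_lightning.metrics.metric import Metric


def _normalize_monitor_value(current, device):
    # Metric objects carry accumulated state; compute() performs the
    # reduction and returns a tensor.
    if isinstance(current, Metric):
        return current.compute()
    # Plain Python numbers are promoted to float tensors on the given
    # device so later comparisons run on one device and dtype.
    if isinstance(current, numbers.Number):
        return torch.tensor(current, device=device, dtype=torch.float)
    # Tensors pass through unchanged; None is guarded by the callers.
    return current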
11 changes: 8 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,6 +20,7 @@

"""

+import numbers
import os
import re
from copy import deepcopy
@@ -32,8 +33,9 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -580,8 +582,11 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
        epoch = metrics.get("epoch")
        step = metrics.get("step")

-        if not isinstance(current, torch.Tensor) and current is not None:
-            current = torch.tensor(current, device=pl_module.device)
+        if current is not None:
+            if isinstance(current, Metric):
+                current = current.compute()
+            elif isinstance(current, numbers.Number):
+                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

        if self.check_monitor_top_k(current):
            self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
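One detail shared by both callback fixes: the number branch now pins dtype=torch.float, whereas the old code let torch.tensor infer the dtype from the Python value. A quick illustration of the difference (plain torch, no Lightning involved):

import torch

inferred = torch.tensor(3)                    # int input -> torch.int64
pinned = torch.tensor(3, dtype=torch.float)   # forced    -> torch.float32
print(inferred.dtype, pinned.dtype)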
5 changes: 4 additions & 1 deletion pytorch_lightning/core/step_result.py
@@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['forked']:
-                result[dl_key] = self[k]
+                if isinstance(self[k], Metric):
+                    result[dl_key] = self[k].compute().detach()
+                else:
+                    result[dl_key] = self[k]

        return result
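The forked branch reduces the metric before storing it, then detaches the result. A toy stand-in for the Metric interface, assuming nothing beyond torch, shows that reduce-then-detach step:

import torch


class MeanMetric:
    # Toy stand-in for pytorch_lightning.metrics.Metric: update()
    # accumulates state across steps; compute() reduces it to one tensor.
    def __init__(self):
        self.total = torch.tensor(0.0)
        self.count = 0

    def update(self, value):
        self.total = self.total + value
        self.count += 1

    def compute(self):
        return self.total / self.count


metric = MeanMetric()
for v in (0.2, 0.4, 0.6):
    metric.update(torch.tensor(v, requires_grad=True))

# Fork the value out of the metric object: reduce, then detach so the
# stored result carries no autograd history.
value = metric.compute().detach()
print(value, value.requires_grad)  # tensor(0.4000) False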
49 changes: 47 additions & 2 deletions tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -26,8 +26,8 @@
from torch.utils.data import Dataset

import pytorch_lightning as pl
-from pytorch_lightning import Trainer, callbacks
-from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning import callbacks, Trainer
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
from tests.base.deterministic_model import DeterministicModel
@@ -771,3 +771,48 @@ def on_train_epoch_end(self, *_):
    trainer.fit(model)
    assert model.epoch_end_called
    assert model.on_train_epoch_end_called
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_metric_are_properly_reduced(tmpdir):
+    class TestingModel(BoringModel):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+            self.train_acc = pl.metrics.Accuracy()
+            self.val_acc = pl.metrics.Accuracy()
+
+        def training_step(self, batch, batch_idx):
+            self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device))
+            self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            preds = torch.tensor(0, device=self.device)
+            targets = torch.tensor(1, device=self.device)
+            if batch_idx < 8:
+                targets = preds
+            self.val_acc(preds, targets)
+            self.log('val_acc', self.val_acc, on_step=True, on_epoch=True)
+            return super().validation_step(batch, batch_idx)
+
+    early_stop = EarlyStopping(monitor='val_acc', mode='max')
+
+    checkpoint = ModelCheckpoint(
+        monitor='val_acc',
+        save_last=True,
+        save_top_k=2,
+        mode='max',
+    )
+
+    model = TestingModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        max_epochs=2,
+        limit_train_batches=5,
+        limit_val_batches=32,
+        callbacks=[early_stop, checkpoint])
+    trainer.fit(model)
+
+    assert trainer.callback_metrics["val_acc"] == 8 / 32.
+    assert "train_acc" in trainer.callback_metrics