Commit 21cfdf6

ref: result 1/n (make monitor default to checkpoint_on to simplify result syntax) (#3571)
* ref: result 1/n (make monitor default to checkpoint_on to simplify result syntax)
* Update pytorch_lightning/callbacks/model_checkpoint.py
  Co-authored-by: ananthsub <[email protected]>
* force crash when max_epochs < epochs in a checkpoint

Co-authored-by: ananthsub <[email protected]>
1 parent 2775389 commit 21cfdf6

23 files changed: +170 −115 lines
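In practical terms, the default `ModelCheckpoint` and `EarlyStopping` callbacks now monitor the keys that `EvalResult`/`TrainResult` populate (`'checkpoint_on'` and `'early_stop_on'`), so the common case needs no `monitor` argument. A minimal sketch of the intended usage, assuming the 0.9-era `pl.EvalResult` API as of this commit (the rest of the module body is a placeholder):

    import pytorch_lightning as pl
    import torch.nn.functional as F

    class LitModel(pl.LightningModule):
        # ... __init__ / forward / training_step elided ...

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self(x), y)
            # checkpoint_on fills the 'checkpoint_on' key the default ModelCheckpoint monitors
            result = pl.EvalResult(checkpoint_on=loss)
            result.log('val_loss', loss)
            return result

    # the default callbacks now line up with the Result keys, no monitor= needed
    trainer = pl.Trainer(checkpoint_callback=True, early_stop_callback=True)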

docs/source/early_stopping.rst

+4-3
@@ -18,9 +18,10 @@ If you do this repeatedly, for every epoch you had originally requested, then th
 
 Default Epoch End Callback Behavior
 -----------------------------------
-By default early stopping will be enabled if `'val_loss'`
-is found in :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`'s
-return dict. Otherwise training will proceed with early stopping disabled.
+By default early stopping will be enabled if the `early_stop_on` key in the EvalResult object is used
+in either the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` method or
+the :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end` method.
+
 
 ----------
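
The documented default can be exercised with a minimal sketch like the following (assuming the `EvalResult` API at this commit; `pl` is `pytorch_lightning` and `F` is `torch.nn.functional`):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # setting early_stop_on populates the 'early_stop_on' key that the
        # default EarlyStopping callback now monitors
        result = pl.EvalResult(early_stop_on=loss)
        result.log('val_loss', loss)
        return result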

docs/source/introduction_guide.rst

+3
@@ -683,6 +683,9 @@ Since the `validation_step` processes a single batch, use the `EvalResult` to lo
 .. code-block:: python
 
     def validation_step(self, batch, batch_idx):
+        loss = MSE_loss(...)
+
+        # loss is a tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
         result = pl.EvalResult(checkpoint_on=loss)
         result.log('val_loss', loss)

docs/source/lightning_module.rst

+2
@@ -288,6 +288,8 @@ For cases like production, you might want to iterate different models inside a L
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         acc = FM.accuracy(y_hat, y)
+
+        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
         result = pl.EvalResult(checkpoint_on=loss)
         result.log_dict({'val_acc': acc, 'val_loss': loss})
         return result

docs/source/weights_loading.rst

+1-1
@@ -45,7 +45,7 @@ To modify the behavior of checkpointing pass in your own callback.
         filepath=os.getcwd(),
         save_top_k=1,
         verbose=True,
-        monitor='val_loss',
+        monitor='checkpoint_on',
         mode='min',
         prefix=''
     )
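
As a usage note, the updated snippet plugs into the Trainer as before; a hedged sketch assuming the 0.9-era Trainer and callback signatures:

    import os
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # the callback now watches 'checkpoint_on', which EvalResult(checkpoint_on=...) fills in
    checkpoint_callback = ModelCheckpoint(
        filepath=os.getcwd(),
        save_top_k=1,
        verbose=True,
        monitor='checkpoint_on',
        mode='min',
        prefix='',
    )
    trainer = Trainer(checkpoint_callback=checkpoint_callback)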

pytorch_lightning/callbacks/early_stopping.py

+3-29
@@ -42,8 +42,7 @@ class EarlyStopping(Callback):
     r"""
 
     Args:
-        monitor: quantity to be monitored. Default: ``'val_loss'``.
-            .. note:: Has no effect when using `EvalResult` or `TrainResult`
+        monitor: quantity to be monitored. Default: ``'early_stop_on'``.
         min_delta: minimum change in the monitored quantity
             to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no
@@ -73,7 +72,7 @@ class EarlyStopping(Callback):
         'max': torch.gt,
     }
 
-    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
+    def __init__(self, monitor: str = 'early_stop_on', min_delta: float = 0.0, patience: int = 3,
                  verbose: bool = False, mode: str = 'auto', strict: bool = True):
         super().__init__()
         self.monitor = monitor
@@ -150,16 +149,6 @@ def on_validation_epoch_end(self, trainer, pl_module):
         if trainer.running_sanity_check:
             return
 
-        self.__warn_deprecated_monitor_key()
-
-        val_es_key = 'val_early_stop_on'
-        if trainer.logger_connector.callback_metrics.get(val_es_key) is not None:
-            self.monitor = val_es_key
-
-        # disable strict checking when using structured results
-        if val_es_key in trainer.logger_connector.callback_metrics:
-            self.strict = False
-
         if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
             # turn off early stopping in on_train_epoch_end
             self.based_on_eval_results = True
@@ -171,29 +160,14 @@ def on_train_epoch_end(self, trainer, pl_module):
 
         # early stopping can also work in the train loop when there is no val loop
         should_check_early_stop = False
-        # early_stop_on takes precedence over monitor key
-        train_es_key = 'early_stop_on'
-        if trainer.logger_connector.callback_metrics.get(train_es_key, None) is not None:
-            self.monitor = train_es_key
-            should_check_early_stop = True
+
         # fallback to monitor key in result dict
         if trainer.logger_connector.callback_metrics.get(self.monitor, None) is not None:
             should_check_early_stop = True
 
         if should_check_early_stop:
             self._run_early_stopping_check(trainer, pl_module)
 
-    def __warn_deprecated_monitor_key(self):
-        using_result_obj = os.environ.get('PL_USING_RESULT_OBJ', None)
-        invalid_key = self.monitor not in ['val_loss', 'early_stop_on', 'val_early_stop_on', 'loss']
-        if using_result_obj and not self.warned_result_obj and invalid_key:
-            self.warned_result_obj = True
-            rank_zero_warn(
-                f"When using `EvalResult(early_stop_on=X)` or `TrainResult(early_stop_on=X)`"
-                " the 'monitor' key of `EarlyStopping` has no effect. "
-                f" Remove `EarlyStopping(monitor='{self.monitor}')` to fix."
-            )
-
     def _run_early_stopping_check(self, trainer, pl_module):
         """
         Checks whether the early stopping condition is met
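
The net effect on `EarlyStopping` is that the callback no longer rewrites its own `monitor` at runtime; it simply checks whether the configured key is present in `callback_metrics`. A hypothetical standalone sketch of that check (plain dicts, not the Lightning classes):

    def should_run_early_stopping_check(callback_metrics: dict, monitor: str = 'early_stop_on') -> bool:
        # mirrors the simplified on_train_epoch_end: no key precedence, just a lookup
        return callback_metrics.get(monitor, None) is not None

    assert should_run_early_stopping_check({'early_stop_on': 0.42})
    assert not should_run_early_stopping_check({'val_acc': 0.91})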

pytorch_lightning/callbacks/model_checkpoint.py

+16-30
@@ -30,6 +30,7 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class ModelCheckpoint(Callback):
@@ -118,7 +119,7 @@ class ModelCheckpoint(Callback):
     def __init__(
         self,
         filepath: Optional[str] = None,
-        monitor: str = "val_loss",
+        monitor: str = "checkpoint_on",
         verbose: bool = False,
         save_last: bool = False,
         save_top_k: int = 1,
@@ -317,48 +318,30 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         ), "tried to make a checkpoint from non global_rank=0"
         self._fs.makedirs(self.dirpath, exist_ok=True)
 
-    def __warn_deprecated_monitor_key(self):
-        using_result_obj = os.environ.get("PL_USING_RESULT_OBJ", None)
-        invalid_key = self.monitor not in [
-            "val_loss",
-            "checkpoint_on",
-            "loss",
-            "val_checkpoint_on",
-        ]
-        if using_result_obj and not self.warned_result_obj and invalid_key:
-            self.warned_result_obj = True
-            rank_zero_warn(
-                f"When using `EvalResult(checkpoint_on=X)` or `TrainResult(checkpoint_on=X)`"
-                " the 'monitor' key of `ModelCheckpoint` has no effect."
-                f" Remove `ModelCheckpoint(monitor='{self.monitor}')` to fix."
-            )
-
     @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
         # only run on main process
         if trainer.global_rank != 0:
             return
 
-        if trainer.running_sanity_check:
+        # no models are saved
+        if self.save_top_k == 0:
             return
 
-        # TODO: remove when dict results are deprecated
-        self.__warn_deprecated_monitor_key()
+        if trainer.running_sanity_check:
+            return
 
         metrics = trainer.logger_connector.callback_metrics
         epoch = trainer.current_epoch
 
-        # support structured results
-        if metrics.get("checkpoint_on") is not None:
-            self.monitor = "checkpoint_on"
-
-        # conditioned val metrics override conditioned train loop metrics
-        if metrics.get("val_checkpoint_on") is not None:
-            self.monitor = "val_checkpoint_on"
+        # validate metric
+        if not self._is_valid_monitor_key(metrics):
+            keys = list(metrics.keys())
+            m = f"""
+                ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics ({keys}),
+                "did you call result.log(f'{self.monitor}', tensor)?"""
+            raise MisconfigurationException(m)
 
-        if self.save_top_k == 0:
-            # no models are saved
-            return
         if (
             self.epoch_last_check is not None
             and (epoch - self.epoch_last_check) < self.period
@@ -420,6 +403,9 @@ def on_validation_end(self, trainer, pl_module):
         if self.last_model_path and self.last_model_path != filepath:
             self._del_model(self.last_model_path)
 
+    def _is_valid_monitor_key(self, metrics):
+        return self.monitor in metrics or len(metrics) == 0
+
     def _do_check_save(
         self,
         filepath: str,
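
For `ModelCheckpoint`, the silent monitor switching is replaced by a fail-fast check. A hypothetical standalone sketch of the new validation logic (plain Python, not the Lightning class):

    def is_valid_monitor_key(metrics: dict, monitor: str = 'checkpoint_on') -> bool:
        # empty metrics are tolerated; otherwise the monitor key must be present,
        # or the callback raises MisconfigurationException
        return monitor in metrics or len(metrics) == 0

    assert is_valid_monitor_key({'checkpoint_on': 0.3})
    assert is_valid_monitor_key({})
    assert not is_valid_monitor_key({'val_acc': 0.91})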

pytorch_lightning/core/step_result.py

+5-2
@@ -801,8 +801,11 @@ def log_dict(
         )
 
     def get_callback_metrics(self) -> dict:
-        result = {'val_early_stop_on': self.early_stop_on, 'val_checkpoint_on': self.checkpoint_on}
-
+        result = {}
+        if self.early_stop_on:
+            result['early_stop_on'] = self.early_stop_on
+        if self.checkpoint_on:
+            result['checkpoint_on'] = self.checkpoint_on
         return result
 
     def write(self, name: str, values: Union[Tensor, list], filename: str = 'predictions.pt'):
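
The change to `get_callback_metrics` means callbacks only ever see the keys the user actually set, under the un-prefixed names they monitor. A hypothetical free-function sketch of the same contract:

    def get_callback_metrics(early_stop_on=None, checkpoint_on=None) -> dict:
        # mirrors the truthiness checks in the diff above
        result = {}
        if early_stop_on:
            result['early_stop_on'] = early_stop_on
        if checkpoint_on:
            result['checkpoint_on'] = checkpoint_on
        return result

    assert get_callback_metrics(checkpoint_on=0.25) == {'checkpoint_on': 0.25}
    assert get_callback_metrics() == {}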

pytorch_lightning/trainer/__init__.py

+5-5
@@ -350,7 +350,7 @@ def on_train_end(self, trainer, pl_module):
         filepath=os.getcwd(),
         save_top_k=True,
         verbose=True,
-        monitor='val_loss',
+        monitor='checkpoint_on',
         mode='min',
         prefix=''
     )
@@ -411,9 +411,9 @@ def on_train_end(self, trainer, pl_module):
 Callback for early stopping.
 early_stop_callback (:class:`pytorch_lightning.callbacks.EarlyStopping`)
 
-    - ``True``: A default callback monitoring ``'val_loss'`` (if dict is returned in validation loop) or
+    - ``True``: A default callback monitoring ``'early_stop_on'`` (if dict is returned in validation loop) or
      ``early_stopping_on`` (if :class:`~pytorch_lightning.core.step_result.Result` is returned) is created.
-      Will raise an error if a dictionary is returned and ``'val_loss'`` is not found.
+      Will raise an error if a dictionary is returned and ``'early_stop_on'`` is not found.
       Will raise an error if a :class:`~pytorch_lightning.core.step_result.Result` is returned
       and ``early_stopping_on`` was not specified.
     - ``False``: Early stopping will be disabled.
@@ -426,15 +426,15 @@ def on_train_end(self, trainer, pl_module):
 
     # default used by the Trainer
     early_stop = EarlyStopping(
-        monitor='val_loss',
+        monitor='early_stop_on',
         patience=3,
         strict=False,
         verbose=False,
         mode='min'
     )
     trainer = Trainer(early_stop_callback=early_stop)
 
-.. note:: If ``'val_loss'`` is not found will work as if early stopping is disabled.
+.. note:: If ``'early_stop_on'`` is not found will work as if early stopping is disabled.
 
 fast_dev_run
 ^^^^^^^^^^^^
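
For reference, the documented behavior amounts to constructing the callback explicitly; a sketch assuming the Trainer signature at this commit (roughly what the docstring above describes for ``early_stop_callback=True``):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stop = EarlyStopping(monitor='early_stop_on', patience=3, strict=False, verbose=False, mode='min')
    trainer = Trainer(early_stop_callback=early_stop)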

pytorch_lightning/trainer/connectors/callback_connector.py

+4-12
@@ -1,7 +1,6 @@
 import os
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_utils import is_overridden
 
 
 class CallbackConnector:
@@ -38,7 +37,7 @@ def on_trainer_init(
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
+        checkpoint_callback = self.init_default_checkpoint_callback(checkpoint_callback)
         if checkpoint_callback:
             self.trainer.callbacks.append(checkpoint_callback)
 
@@ -51,18 +50,11 @@ def on_trainer_init(
             progress_bar_refresh_rate, process_position
         )
 
-    def configure_checkpoint_callback(self, checkpoint_callback):
+    def init_default_checkpoint_callback(self, checkpoint_callback):
         if checkpoint_callback is True:
-            # when no val step is defined, use 'loss' otherwise 'val_loss'
-            train_step_only = not is_overridden('validation_step', self.trainer.get_model())
-            monitor_key = 'loss' if train_step_only else 'val_loss'
-            checkpoint_callback = ModelCheckpoint(
-                filepath=None,
-                monitor=monitor_key
-            )
+            checkpoint_callback = ModelCheckpoint(filepath=None)
         elif checkpoint_callback is False:
             checkpoint_callback = None
-
         if checkpoint_callback:
             checkpoint_callback.save_function = self.trainer.save_checkpoint
 
@@ -71,7 +63,7 @@ def configure_checkpoint_callback(self, checkpoint_callback):
     def configure_early_stopping(self, early_stop_callback):
         if early_stop_callback is True or None:
             early_stop_callback = EarlyStopping(
-                monitor='val_loss',
+                monitor='early_stop_on',
                 patience=3,
                 strict=True,
                 verbose=True,
pytorch_lightning/trainer/connectors/logger_connector.py

+14-4
@@ -110,9 +110,9 @@ def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
         if using_eval_result:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
-                    self.trainer.logger_connector.callback_metrics = eval_result.callback_metrics
+                    self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
             else:
-                self.trainer.logger_connector.callback_metrics = eval_results.callback_metrics
+                self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
         else:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
@@ -121,13 +121,23 @@ def _log_on_evaluation_epoch_end_metrics(self, eval_results, using_eval_result):
                         flat = {'val_loss': eval_result}
                     else:
                         flat = flatten_dict(eval_result)
+
+                    # removing val_loss magic word to map to checkpoint + ES callback
+                    if 'val_loss' in flat:
+                        flat['checkpoint_on'] = flat['val_loss']
+                        flat['early_stop_on'] = flat['val_loss']
                     self.trainer.logger_connector.callback_metrics.update(flat)
             else:
                 # with a scalar return, auto set it to "val_loss" for callbacks
                 if isinstance(eval_results, torch.Tensor):
                     flat = {'val_loss': eval_results}
                 else:
                     flat = flatten_dict(eval_results)
+
+                # removing val_loss magic word to map to checkpoint + ES callback
+                if 'val_loss' in flat:
+                    flat['checkpoint_on'] = flat['val_loss']
+                    flat['early_stop_on'] = flat['val_loss']
                 self.trainer.logger_connector.callback_metrics.update(flat)
 
     def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
@@ -151,7 +161,7 @@ def __log_evaluation_epoch_metrics_2(self, eval_results, test_mode):
         if test_mode:
             callback_metrics = {}
         else:
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_output(result)
+            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
 
         # eval loop returns all metrics
         dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
@@ -239,7 +249,7 @@ def log_train_epoch_end_metrics(self,
             epoch_log_metrics = epoch_output.epoch_log_metrics
             epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
         else:
-            _processed_outputs = self.trainer.process_output(epoch_output)
+            _processed_outputs = self.trainer.process_dict_result(epoch_output)
             epoch_progress_bar_metrics = _processed_outputs[1]
             epoch_log_metrics = _processed_outputs[2]
             epoch_callback_metrics = _processed_outputs[3]
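
The mapping added above can be summarized as: dict or scalar validation results that use the `'val_loss'` magic word are copied onto the keys the callbacks now monitor. A hypothetical standalone sketch:

    def map_val_loss_to_callback_keys(flat: dict) -> dict:
        # 'val_loss' is kept for logging; the callback keys are added alongside it
        if 'val_loss' in flat:
            flat['checkpoint_on'] = flat['val_loss']
            flat['early_stop_on'] = flat['val_loss']
        return flat

    assert map_val_loss_to_callback_keys({'val_loss': 0.7}) == {
        'val_loss': 0.7, 'checkpoint_on': 0.7, 'early_stop_on': 0.7
    }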

pytorch_lightning/trainer/logging.py

+12-1
@@ -52,7 +52,7 @@ def metrics_to_scalars(self, metrics):
 
         return new_metrics
 
-    def process_output(self, output, train=False):
+    def process_dict_result(self, output, train=False):
         """Reduces output according to the training mode.
 
         Separates loss from logging and progress bar metrics
@@ -147,6 +147,17 @@ def process_output(self, output, train=False):
         # no .item() because it will slow things down
         callback_metrics = recursive_detach(callback_metrics)
 
+        # replace loss with checkpoint_on
+        if 'loss' in callback_metrics:
+            callback_metrics['checkpoint_on'] = callback_metrics['loss']
+            callback_metrics['early_stop_on'] = callback_metrics['loss']
+            del callback_metrics['loss']
+
+        if 'val_loss' in callback_metrics:
+            callback_metrics['checkpoint_on'] = callback_metrics['val_loss']
+            callback_metrics['early_stop_on'] = callback_metrics['val_loss']
+            del callback_metrics['val_loss']
+
         return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
 
     def reduce_distributed_output(self, output, num_gpus):
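
Unlike the logger-connector mapping, `process_dict_result` renames rather than copies: `'loss'`/`'val_loss'` are removed from the callback metrics after being mirrored onto the new keys. A hypothetical standalone sketch:

    def remap_callback_metrics(callback_metrics: dict) -> dict:
        # val_loss is processed second, so it wins if both keys are present
        for old_key in ('loss', 'val_loss'):
            if old_key in callback_metrics:
                callback_metrics['checkpoint_on'] = callback_metrics[old_key]
                callback_metrics['early_stop_on'] = callback_metrics[old_key]
                del callback_metrics[old_key]
        return callback_metrics

    assert remap_callback_metrics({'val_loss': 0.5}) == {'checkpoint_on': 0.5, 'early_stop_on': 0.5}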
