From a68e91b81ffafdb34b9848437ae194816c70e3a3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 00:56:05 +0100 Subject: [PATCH 1/4] Remove EarlyStopping(mode='auto') --- CHANGELOG.md | 3 ++ pytorch_lightning/callbacks/early_stopping.py | 37 ++++--------------- tests/deprecated_api/test_remove_1-3.py | 3 -- 3 files changed, 10 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e9a0fea8cf2c..ee08d5267d269 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed +- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167)) + + ### Fixed - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 384ce9699f60e..409de16457579 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -23,7 +23,7 @@ import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -40,23 +40,18 @@ class EarlyStopping(Callback): patience: number of validation epochs with no improvement after which training will be stopped. Default: ``3``. verbose: verbosity mode. Default: ``False``. - mode: one of {auto, min, max}. In `min` mode, + mode: one of {min, max}. In `min` mode, training will stop when the quantity monitored has stopped decreasing; in `max` mode it will stop when the quantity - monitored has stopped increasing; in `auto` - mode, the direction is automatically inferred - from the name of the monitored quantity. - - .. warning:: - Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3. + monitored has stopped increasing. strict: whether to crash the training if `monitor` is not found in the validation metrics. Default: ``True``. Raises: MisconfigurationException: - If ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``. + If ``mode`` is none of ``"min"``, ``"max"``. RuntimeError: If the metric ``monitor`` is not available. @@ -78,7 +73,7 @@ def __init__( min_delta: float = 0.0, patience: int = 3, verbose: bool = False, - mode: str = 'auto', + mode: str = 'min', strict: bool = True, ): super().__init__() @@ -92,31 +87,13 @@ def __init__( self.mode = mode self.warned_result_obj = False - self.__init_monitor_mode() + if self.mode not in self.mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") self.min_delta *= 1 if self.monitor_op == torch.gt else -1 torch_inf = torch.tensor(np.Inf) self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf - def __init_monitor_mode(self): - if self.mode not in self.mode_dict and self.mode != 'auto': - raise MisconfigurationException(f"`mode` can be auto, {', '.join(self.mode_dict.keys())}, got {self.mode}") - - # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 - if self.mode == 'auto': - rank_zero_warn( - "mode='auto' is deprecated in v1.1 and will be removed in v1.3." - " Default value for mode with be 'min' in v1.3.", DeprecationWarning - ) - - if "acc" in self.monitor or self.monitor.startswith("fmeasure"): - self.mode = 'max' - else: - self.mode = 'min' - - if self.verbose > 0: - rank_zero_info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') - def _validate_condition_metric(self, logs): monitor_val = logs.get(self.monitor) diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 99cb280e96797..0b61e854d23c8 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -36,9 +36,6 @@ def test_v1_3_0_deprecated_arguments(tmpdir): with pytest.deprecated_call(match='will be removed in v1.3'): ModelCheckpoint(mode='auto') - with pytest.deprecated_call(match='will be removed in v1.3'): - EarlyStopping(mode='auto') - with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): class DeprecatedHparamsModel(LightningModule): From 52d55e3fdfe16fa8404e1f699a9b605df07d6a32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Feb 2021 03:14:28 +0100 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Roger Shieh --- pytorch_lightning/callbacks/early_stopping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 409de16457579..f30a605f2de3c 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -40,9 +40,9 @@ class EarlyStopping(Callback): patience: number of validation epochs with no improvement after which training will be stopped. Default: ``3``. verbose: verbosity mode. Default: ``False``. - mode: one of {min, max}. In `min` mode, + mode: one of ``'min'``, ``'max'``. In ``min`` mode, training will stop when the quantity - monitored has stopped decreasing; in `max` + monitored has stopped decreasing; in ``max`` mode it will stop when the quantity monitored has stopped increasing. From 83962d93c665728897d7bbb010341d1b13427407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Feb 2021 13:58:14 +0100 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/callbacks/early_stopping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f30a605f2de3c..d188aebe96489 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -40,9 +40,9 @@ class EarlyStopping(Callback): patience: number of validation epochs with no improvement after which training will be stopped. Default: ``3``. verbose: verbosity mode. Default: ``False``. - mode: one of ``'min'``, ``'max'``. In ``min`` mode, + mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the quantity - monitored has stopped decreasing; in ``max`` + monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity monitored has stopped increasing. @@ -51,7 +51,7 @@ class EarlyStopping(Callback): Raises: MisconfigurationException: - If ``mode`` is none of ``"min"``, ``"max"``. + If ``mode`` is none of ``"min"`` or ``"max"``. RuntimeError: If the metric ``monitor`` is not available. From 3fa2d052851e083a54f78440cded0cdcde08243a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 14:00:16 +0100 Subject: [PATCH 4/4] Update test --- tests/callbacks/test_early_stopping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 9d326f045544e..f36b68c0dacf2 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -334,7 +334,7 @@ def validation_epoch_end(self, outputs): # Compute min_epochs latest step by_min_epochs = min_epochs * limit_train_batches - # Make sure the trainer stops for the max of all minimun requirements + # Make sure the trainer stops for the max of all minimum requirements assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) @@ -342,5 +342,5 @@ def validation_epoch_end(self, outputs): def test_early_stopping_mode_options(): - with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"): + with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"): EarlyStopping(mode="unknown_option")