diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml
index 8dc5b526bf3a..fce27c1df773 100644
--- a/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_prompt_learning_config.yaml
@@ -15,6 +15,7 @@ trainer:
   gradient_clip_val: 1.0
   resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
   benchmark: False
+
 
 
 exp_manager:
@@ -36,6 +37,14 @@ exp_manager:
     filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}'
     model_parallel_size: ${model.tensor_model_parallel_size}
     save_best_model: True
+  create_early_stopping_callback: True
+  early_stopping_callback_params:
+    monitor: "val_loss"
+    mode: "min"
+    min_delta: 0.001
+    patience: 10
+    verbose: True
+
 
 model:
   seed: 1234
diff --git a/examples/nlp/language_modeling/conf/megatron_t5_prompt_learning.yaml b/examples/nlp/language_modeling/conf/megatron_t5_prompt_learning.yaml
index 3df6ea6dbac9..85252926777b 100644
--- a/examples/nlp/language_modeling/conf/megatron_t5_prompt_learning.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_t5_prompt_learning.yaml
@@ -34,6 +34,13 @@ exp_manager:
     filename: "megatron_t5_prompt_tune--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}"
     model_parallel_size: ${model.tensor_model_parallel_size}
     save_best_model: True
+  create_early_stopping_callback: True
+  early_stopping_callback_params:
+    monitor: "val_loss"
+    mode: "min"
+    min_delta: 0.001
+    patience: 10
+    verbose: True
 
 model:
   seed: 1234
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 769fb24680ea..0f0f358b997a 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -32,6 +32,7 @@
 from hydra.utils import get_original_cwd
 from omegaconf import DictConfig, OmegaConf, open_dict
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.timer import Interval, Timer
 from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger
 from pytorch_lightning.loops import TrainingEpochLoop
@@ -69,6 +70,21 @@ class CheckpointMisconfigurationError(NeMoBaseException):
     """ Raised when a mismatch between trainer.callbacks and exp_manager occurs"""
 
 
+@dataclass
+class EarlyStoppingParams:
+    monitor: str = "val_loss"  # The metric that early stopping should consider.
+    mode: str = "min"  # Whether early stopping should look for a decrease ("min") or an increase ("max") in the monitored metric.
+    min_delta: float = 0.001  # Smallest change in the monitored metric that counts as an improvement.
+    patience: int = 10  # Number of consecutive validation cycles with no improvement before training is stopped.
+    verbose: bool = True
+    strict: bool = True
+    check_finite: bool = True
+    stopping_threshold: Optional[float] = None
+    divergence_threshold: Optional[float] = None
+    check_on_train_epoch_end: Optional[bool] = None
+    log_rank_zero_only: bool = False
+
+
 @dataclass
 class CallbackParams:
     filepath: Optional[str] = None  # Deprecated
@@ -153,6 +169,8 @@ class ExpManagerConfig:
     # Checkpointing parameters
     create_checkpoint_callback: Optional[bool] = True
     checkpoint_callback_params: Optional[CallbackParams] = CallbackParams()
+    create_early_stopping_callback: Optional[bool] = False
+    early_stopping_callback_params: Optional[EarlyStoppingParams] = EarlyStoppingParams()
     # Additional exp_manager arguments
     files_to_copy: Optional[List[str]] = None
     # logs timing of train/val/test steps
@@ -272,6 +290,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
             pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most
             recent checkpoint under ``*last.ckpt``, and the final checkpoint after training completes under ``*end.ckpt``.
            Defaults to True.
+        - create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Default is False.
+            See EarlyStoppingParams dataclass above.
         - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None
            which copies no files.
         - log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
@@ -420,6 +440,10 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
             )
             trainer.callbacks.append(ema_callback)
 
+    if cfg.create_early_stopping_callback:
+        early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
+        trainer.callbacks.append(early_stop_callback)
+
     if cfg.create_checkpoint_callback:
         configure_checkpointing(
             trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params
diff --git a/tests/core/test_config_utils.py b/tests/core/test_config_utils.py
index 4eccd90afada..b449fd8fc881 100644
--- a/tests/core/test_config_utils.py
+++ b/tests/core/test_config_utils.py
@@ -17,9 +17,11 @@
 
 import pytest
 import pytorch_lightning as ptl
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 
 from nemo.core.config.pytorch_lightning import TrainerConfig
 from nemo.utils import config_utils
+from nemo.utils.exp_manager import EarlyStoppingParams
 
 
 @pytest.fixture()
@@ -126,3 +128,12 @@ def test_ptl_config(self):
         assert signatures_match
         assert cls_subset is None
         assert dataclass_subset is None
+
+    @pytest.mark.unit
+    def test_early_stopping_config(self,):
+        result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams)
+        signatures_match, cls_subset, dataclass_subset = result
+
+        assert signatures_match
+        assert cls_subset is None
+        assert dataclass_subset is None
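
A minimal usage sketch, not part of the diff itself: it mirrors what the new exp_manager branch does when create_early_stopping_callback is set, building PyTorch Lightning's EarlyStopping callback from EarlyStoppingParams and attaching it to the trainer. The Trainer arguments below are illustrative assumptions, not values taken from this PR.

    import pytorch_lightning as pl
    from omegaconf import OmegaConf
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    from nemo.utils.exp_manager import EarlyStoppingParams

    # Same values as the YAML added to the prompt-learning configs above.
    params = OmegaConf.structured(
        EarlyStoppingParams(monitor="val_loss", mode="min", min_delta=0.001, patience=10, verbose=True)
    )

    # Illustrative trainer settings; NeMo example scripts pass logger=False and
    # enable_checkpointing=False so that exp_manager can manage both itself.
    trainer = pl.Trainer(logger=False, enable_checkpointing=False, max_steps=1000)

    # Equivalent to the new branch in exp_manager():
    #     early_stop_callback = EarlyStopping(**cfg.early_stopping_callback_params)
    #     trainer.callbacks.append(early_stop_callback)
    trainer.callbacks.append(EarlyStopping(**params))

Keeping EarlyStoppingParams field-for-field in sync with EarlyStopping's constructor is what the new test_early_stopping_config unit test enforces via assert_dataclass_signature_match.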