minor updates and added dataclass check
Signed-off-by: arendu <[email protected]>
arendu committed Feb 16, 2023
1 parent 0aefe59 commit 2f1111d
Showing 3 changed files with 14 additions and 1 deletion.
@@ -34,6 +34,7 @@ exp_manager:
  filename: "megatron_t5_prompt_tune--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}"
  model_parallel_size: ${model.tensor_model_parallel_size}
  save_best_model: True
+ create_early_stopping_callback: True
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
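With create_early_stopping_callback enabled, exp_manager is expected to construct a pytorch_lightning EarlyStopping callback from early_stopping_callback_params and attach it to the trainer. A minimal sketch of that wiring, assuming the config keys above map directly to EarlyStopping arguments (the build_early_stopping helper is illustrative, not the actual NeMo source):

```python
# Illustrative sketch only: assumes exp_manager passes
# early_stopping_callback_params through to EarlyStopping
# when create_early_stopping_callback is True.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping


def build_early_stopping(params: dict) -> EarlyStopping:
    # monitor/mode come straight from the YAML above; the remaining
    # EarlyStopping arguments fall back to their library defaults.
    return EarlyStopping(monitor=params["monitor"], mode=params["mode"])


callback = build_early_stopping({"monitor": "val_loss", "mode": "min"})
trainer = Trainer(callbacks=[callback], max_epochs=10)
```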
3 changes: 2 additions & 1 deletion nemo/utils/exp_manager.py
@@ -284,7 +284,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
pytorch lightning trainer. The ModelCheckpoint saves the top 3 models with the best "val_loss", the most
recent checkpoint under ``*last.ckpt``, and the final checkpoint after training completes under ``*end.ckpt``.
Defaults to True.
- - create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Default is False. See EarlyStoppingParams dataclass above.
+ - create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Default is False.
+   See EarlyStoppingParams dataclass above.
- files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
copies no files.
- log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
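The docstring points at an EarlyStoppingParams dataclass defined earlier in exp_manager.py. A plausible sketch of its shape, mirroring the constructor of pytorch_lightning's EarlyStopping, which is what the signature-match test below relies on (the exact NeMo field names and defaults are assumptions):

```python
from dataclasses import dataclass
from typing import Optional


# Assumed shape only: fields mirror the arguments of
# pytorch_lightning.callbacks.early_stopping.EarlyStopping.
@dataclass
class EarlyStoppingParams:
    monitor: str = "val_loss"      # metric watched for improvement
    mode: str = "min"              # "min" stops when the metric stops decreasing
    min_delta: float = 0.001      # smallest change counted as an improvement
    patience: int = 10             # evaluations without improvement before stopping
    verbose: bool = True
    strict: bool = True            # error out if monitor is missing from logged metrics
    check_finite: bool = True      # stop on NaN/inf metric values
    stopping_threshold: Optional[float] = None
    divergence_threshold: Optional[float] = None
    check_on_train_epoch_end: Optional[bool] = None
```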
11 changes: 11 additions & 0 deletions tests/core/test_config_utils.py
@@ -17,6 +17,8 @@

import pytest
import pytorch_lightning as ptl
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+ from nemo.utils.exp_manager import EarlyStoppingParams

from nemo.core.config.pytorch_lightning import TrainerConfig
from nemo.utils import config_utils
@@ -126,3 +128,12 @@ def test_ptl_config(self):
assert signatures_match
assert cls_subset is None
assert dataclass_subset is None

+ @pytest.mark.unit
+ def test_early_stopping_config(self):
+     result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams)
+     signatures_match, cls_subset, dataclass_subset = result
+
+     assert signatures_match
+     assert cls_subset is None
+     assert dataclass_subset is None

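The new test leans on config_utils.assert_dataclass_signature_match, which compares a class constructor against a dataclass and reports anything present on one side but not the other. A rough sketch of what such a check plausibly does (simplified; the actual NeMo helper may additionally ignore selected arguments and compare types):

```python
import inspect
from dataclasses import fields, is_dataclass


def assert_dataclass_signature_match(cls, datacls):
    # Simplified sketch: collect __init__ parameter names (minus self)
    # and dataclass field names, then diff the two sets.
    assert is_dataclass(datacls)
    cls_params = {p for p in inspect.signature(cls.__init__).parameters if p != "self"}
    dc_fields = {f.name for f in fields(datacls)}
    cls_only = sorted(cls_params - dc_fields) or None   # in class but not dataclass
    dc_only = sorted(dc_fields - cls_params) or None    # in dataclass but not class
    return (cls_only is None and dc_only is None), cls_only, dc_only
```

The returned triple matches how the test unpacks it: a boolean for an exact match, plus the leftover names on each side, both None when the signatures agree.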