Skip to content

Commit

Permalink
Fix serialization of AutoResume (#9616)
Browse files (browse the repository at this point in the history)
* fix serialization of autoresume

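* replace undefined local variables in `nemo_path` with instance attributes (`resume_from_checkpoint` -> `self.path`, `resume_past_end` -> `self.resume_past_end`)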
  • Loading branch information
sararb committed Jul 4, 2024
1 parent 0f157ab commit 32286ed
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions nemo/lightning/resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import lightning_fabric as fl
import pytorch_lightning as pl

from nemo.lightning import io
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import uninject_model_parallel_rank


class Resume:
Expand All @@ -22,7 +24,7 @@ def setup(self, model, trainer: Union[pl.Trainer, fl.Fabric]):
trainer.checkpoint_callback.last_model_path = ckpt_path


class AutoResume(Resume):
class AutoResume(Resume, io.IOMixin):
"""Class that handles the logic for setting checkpoint paths and restoring from
checkpoints in NeMo.
"""
Expand Down Expand Up @@ -101,15 +103,15 @@ def nemo_path(self, model=None) -> Optional[Path]:
warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. "
if checkpoint is None:
warn += "Training from scratch."
elif checkpoint == resume_from_checkpoint:
warn += f"Training from {resume_from_checkpoint}."
elif checkpoint == self.path:
warn += f"Training from {self.path}."
logging.warning(warn)
else:
raise NotFoundError(
f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume."
)
elif len(end_checkpoints) > 0:
if resume_past_end:
if self.resume_past_end:
if len(end_checkpoints) > 1:
if 'mp_rank' in str(end_checkpoints[0]):
checkpoint = end_checkpoints[0]
Expand Down

0 comments on commit 32286ed

Please sign in to comment.