Revised partial checkpoint support for Sagemaker Model Parallel #16950
Changes from all commits
```diff
@@ -17,6 +17,7 @@
 """

 import contextlib
+import glob
 import inspect
 import math
 import os
```
```diff
@@ -1193,7 +1194,13 @@ def train(
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

         if resume_from_checkpoint is not None:
-            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
+            checkpoint_file_exists = (
+                glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
+                if is_sagemaker_mp_enabled()
+                else os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
+            )
```
Comment on lines +1198 to +1202

Collaborator: This is used several times, could we refactor it in a util function that takes the filename?

Contributor (Author): I will do that.
```diff
+            if not checkpoint_file_exists:
                 raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

             logger.info(f"Loading model from {resume_from_checkpoint}).")
```
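For context, with `WEIGHTS_NAME = "pytorch_model.bin"` an SMP partial checkpoint shows up on disk as one file per rank, e.g. `pytorch_model.bin_0_0`, `pytorch_model.bin_0_1`, ..., which is why a plain `os.path.isfile` check no longer works. A minimal sketch of the utility function suggested above might look like this (the helper name and signature are illustrative, not part of this PR):

```python
import glob
import os


def checkpoint_exists(folder: str, filename: str, smp_enabled: bool) -> bool:
    """Hypothetical helper: True if `filename`, or any SMP partial shard of it, exists in `folder`."""
    path = os.path.join(folder, filename)
    if smp_enabled:
        # Partial SMP checkpoints are written as {filename}_{pp_rank}_{tp_rank}[_{rdp_rank}].
        return len(glob.glob(path + "_*")) > 0
    return os.path.isfile(path)
```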
```diff
@@ -1211,6 +1218,9 @@ def train(
             if args.deepspeed:
                 # will be resumed in deepspeed_init
                 pass
+            elif is_sagemaker_mp_enabled():
+                # will be resumed after model is wrapped
+                pass
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
```
```diff
@@ -1299,6 +1309,10 @@ def train(
         model = self._wrap_model(self.model_wrapped)

+        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+            state_dict = smp.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=True)
+            model.load_state_dict(state_dict)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
```
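The ordering matters here: the partial state dict is only loaded after `_wrap_model` has produced the SMP-wrapped model, so each process picks up the shard matching its own ranks. Outside the Trainer, the same resume pattern reduces to roughly the sketch below (the toy model and path are placeholders, and `smp` is only importable inside a SageMaker model-parallel job):

```python
import torch
import smdistributed.modelparallel.torch as smp  # available inside SageMaker MP jobs only

smp.init()  # assumes the MP configuration is supplied by the SageMaker launcher

# Wrap first, then load: the partial files are keyed by pp/tp (and rdp) rank.
model = smp.DistributedModel(torch.nn.Linear(8, 8))
state_dict = smp.load("/opt/ml/checkpoints/pytorch_model.bin", partial=True)
model.load_state_dict(state_dict)
```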
```diff
@@ -1561,13 +1575,19 @@ def train(
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()

             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )

             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
-            if os.path.exists(best_model_path):
+            checkpoint_file_exists = (
+                glob.glob(best_model_path + "_*") if is_sagemaker_mp_enabled() else os.path.exists(best_model_path)
+            )
+
+            if checkpoint_file_exists:
                 if self.deepspeed:
                     # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                     deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
```
```diff
@@ -1579,6 +1599,9 @@ def train(
                     self.deepspeed.load_checkpoint(
                         self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
                     )
+                elif is_sagemaker_mp_enabled():
+                    state_dict = smp.load(best_model_path, partial=True)
+                    model.load_state_dict(state_dict)
                 else:
                     # We load the model state dict on the CPU to avoid an OOM error.
                     state_dict = torch.load(best_model_path, map_location="cpu")
```
```diff
@@ -1741,17 +1764,20 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
             reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
-                opt_state_dict = self.optimizer.state_dict()
-                # Save it and the scheduler on the main process
-                if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
-                    with warnings.catch_warnings(record=True) as caught_warnings:
-                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                    reissue_pt_warnings(caught_warnings)
-                    if self.do_grad_scaling:
-                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            if self.args.should_save or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+            if self.args.should_save:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+                if self.do_grad_scaling:
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
```
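The change above reverses the old strategy of consolidating the optimizer state on `rdp_rank() == 0` and writing a single file from the main process: each rank now hands over only its local shard, and `smp.save(..., partial=True)` writes rank-suffixed files. A rough standalone sketch of that logic, with a hypothetical helper name and output filename (the real code of course lives inside `_save_checkpoint`):

```python
import os

import smdistributed.modelparallel.torch as smp


def save_partial_optimizer(optimizer, output_dir: str, is_main_process: bool) -> None:
    """Illustrative helper mirroring the PR's optimizer-saving logic under SMP."""
    # Ask each rank only for the shard it owns instead of gathering the full state.
    opt_state_dict = optimizer.local_state_dict(gather_if_shard=False)
    # With optimizer-state sharding every rank owns a distinct shard and must write;
    # otherwise writing from the main process is sufficient.
    if is_main_process or smp.state.cfg.shard_optimizer_state:
        smp.save(
            opt_state_dict,
            os.path.join(output_dir, "optimizer.pt"),
            partial=True,
            v3=smp.state.cfg.shard_optimizer_state,
        )
```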
```diff
@@ -1822,9 +1848,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return

-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
```
```diff
@@ -1840,9 +1869,16 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
```
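In the SMP branch the optimizer state is not loaded immediately: SMP partitions the model during the first training step, so the optimizer's local state only lines up with the saved partial shards after that point, and the load is therefore deferred through a post-step hook on the wrapped model. A toy illustration of that deferral pattern, with placeholder paths and outside any Trainer context:

```python
import os

import torch
import smdistributed.modelparallel.torch as smp

smp.init()
model = smp.DistributedModel(torch.nn.Linear(8, 8))
optimizer = smp.DistributedOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))

checkpoint_dir = "/opt/ml/checkpoints"  # placeholder


def opt_load_hook(mod, opt):
    # Runs after a training step, once the partition (and hence the optimizer shards) exist.
    opt.load_state_dict(smp.load(os.path.join(checkpoint_dir, "optimizer.pt"), partial=True))


model.register_post_step_hook(opt_load_hook)
```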
```diff
@@ -2114,7 +2150,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             self._save_tpu(output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
-            state_dict = self.model_wrapped.state_dict()
+            state_dict = self.model_wrapped.local_state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
```
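The switch from `state_dict()` to `local_state_dict()` is the crux of the "partial" behaviour: on an SMP-wrapped model, `state_dict()` gathers the full, unpartitioned parameters across ranks, whereas `local_state_dict()` returns only the partition owned by the calling process, which is what `smp.save(..., partial=True)` expects. A tiny sketch contrasting the two (toy model and path, SMP job assumed):

```python
import torch
import smdistributed.modelparallel.torch as smp

smp.init()
model = smp.DistributedModel(torch.nn.Linear(8, 8))

full_sd = model.state_dict()         # gathers the whole model; must be called on all processes
local_sd = model.local_state_dict()  # only this rank's partition; pairs with smp.save(partial=True)
smp.save(local_sd, "/opt/ml/checkpoints/pytorch_model.bin", partial=True)
```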
```diff
@@ -2202,9 +2238,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 if state_dict is None:
                     state_dict = self.model.state_dict()
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                if is_sagemaker_mp_enabled():
+                    smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            if is_sagemaker_mp_enabled():
+                smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
+            else:
+                self.model.save_pretrained(output_dir, state_dict=state_dict)
```
Comment on lines +2246 to +2249

Collaborator: No calling …

Contributor (Author): SMP checkpoints are saved partially, hence we do not want to shard SMP checkpoints. In order to use …
```diff

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
```