diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index eb0a3ce2ee54..0eb232df5ac5 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -17,6 +17,7 @@
 """

 import contextlib
+import glob
 import inspect
 import math
 import os
@@ -1193,7 +1194,13 @@ def train(
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

         if resume_from_checkpoint is not None:
-            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
+            checkpoint_file_exists = (
+                glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
+                if is_sagemaker_mp_enabled()
+                else os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
+            )
+            if not checkpoint_file_exists:
                 raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

             logger.info(f"Loading model from {resume_from_checkpoint}).")
@@ -1211,6 +1218,9 @@ def train(
             if args.deepspeed:
                 # will be resumed in deepspeed_init
                 pass
+            elif is_sagemaker_mp_enabled():
+                # will be resumed after model is wrapped
+                pass
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
@@ -1299,6 +1309,10 @@ def train(

         model = self._wrap_model(self.model_wrapped)

+        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+            state_dict = smp.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=True)
+            model.load_state_dict(state_dict)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
@@ -1561,13 +1575,19 @@ def train(
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()

             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )

             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
-            if os.path.exists(best_model_path):
+            checkpoint_file_exists = (
+                glob.glob(best_model_path + "_*") if is_sagemaker_mp_enabled() else os.path.exists(best_model_path)
+            )
+
+            if checkpoint_file_exists:
                 if self.deepspeed:
                     # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                     deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
@@ -1579,6 +1599,9 @@ def train(
                     self.deepspeed.load_checkpoint(
                         self.state.best_model_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
                     )
+                elif is_sagemaker_mp_enabled():
+                    state_dict = smp.load(best_model_path, partial=True)
+                    model.load_state_dict(state_dict)
                 else:
                     # We load the model state dict on the CPU to avoid an OOM error.
                     state_dict = torch.load(best_model_path, map_location="cpu")
@@ -1741,17 +1764,20 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
-                opt_state_dict = self.optimizer.state_dict()
-                # Save it and the scheduler on the main process
-                if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
-                    with warnings.catch_warnings(record=True) as caught_warnings:
-                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                    reissue_pt_warnings(caught_warnings)
-                    if self.do_grad_scaling:
-                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            if self.args.should_save or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+            if self.args.should_save:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+                if self.do_grad_scaling:
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1822,9 +1848,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return

-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1840,9 +1869,16 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
             reissue_pt_warnings(caught_warnings)
@@ -2114,7 +2150,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             self._save_tpu(output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
-            state_dict = self.model_wrapped.state_dict()
+            state_dict = self.model_wrapped.local_state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
@@ -2202,9 +2238,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 if state_dict is None:
                     state_dict = self.model.state_dict()
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                if is_sagemaker_mp_enabled():
+                    smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            if is_sagemaker_mp_enabled():
+                smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
+            else:
+                self.model.save_pretrained(output_dir, state_dict=state_dict)

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index cc0a5ec83570..2b6fc957b205 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1219,7 +1219,7 @@ def should_save(self):
             return self.local_process_index == 0
         else:
             if is_sagemaker_mp_enabled():
-                return smp.rank() == 0
+                return smp.rdp_rank() == 0
             else:
                 return self.process_index == 0
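Note (not part of the patch): a minimal sketch of the partial-checkpoint round trip the changes above rely on. It assumes an SMP-enabled SageMaker job where `model` is already wrapped as an `smp.DistributedModel`; the helper names `save_partial` and `resume_partial` are illustrative only, and only the `smp.save`/`smp.load(partial=True)` calls and the `WEIGHTS_NAME + "_*"` glob come from the diff itself.

# Illustrative only: each rank writes its own shard named
# {WEIGHTS_NAME}_{pp_rank}_{tp_rank} or {WEIGHTS_NAME}_{pp_rank}_{tp_rank}_{rdp_rank},
# which is what the glob.glob(... + "_*") existence checks above look for.
import glob
import os

import smdistributed.modelparallel.torch as smp  # only available inside an SMP training job

WEIGHTS_NAME = "pytorch_model.bin"  # same constant the Trainer uses


def save_partial(model, output_dir):
    # Each rank saves its local shard; partial=True produces the suffixed filenames.
    smp.save(model.local_state_dict(), os.path.join(output_dir, WEIGHTS_NAME), partial=True)


def resume_partial(model, checkpoint_dir):
    # Mirrors the Trainer check: partial shards share the WEIGHTS_NAME prefix.
    if not glob.glob(os.path.join(checkpoint_dir, WEIGHTS_NAME) + "_*"):
        raise ValueError(f"No partial checkpoint found in {checkpoint_dir}")
    # smp.load reassembles the shards this rank needs; call it after the model is wrapped.
    model.load_state_dict(smp.load(os.path.join(checkpoint_dir, WEIGHTS_NAME), partial=True))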