diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ad221e1d8819..43a1d750f33c 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -18,6 +18,7 @@
 
 import contextlib
 import functools
+import glob
 import inspect
 import math
 import os
@@ -1302,7 +1303,7 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if resume_from_checkpoint is not None:
+        if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled():
            self._load_from_checkpoint(resume_from_checkpoint)
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1401,6 +1402,9 @@ def _inner_training_loop(
 
         model = self._wrap_model(self.model_wrapped)
 
+        if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+            self._load_from_checkpoint(resume_from_checkpoint, model)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
@@ -1666,6 +1670,8 @@ def _inner_training_loop(
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
 
             self._load_best_model()
 
@@ -1688,7 +1694,12 @@ def _inner_training_loop(
 
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
-    def _load_from_checkpoint(self, resume_from_checkpoint):
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+
+        if model is None:
+            model = self.model
+        strict_load = is_sagemaker_mp_enabled()
+
         if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
             os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
         ):
@@ -1713,20 +1724,22 @@ def _load_from_checkpoint(self, resume_from_checkpoint):
             # We load the model state dict on the CPU to avoid an OOM error.
             state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
             # If the model is on the GPU, it still works!
-            load_result = self.model.load_state_dict(state_dict, strict=False)
-            self._issue_warnings_after_load(load_result)
-
+            load_result = model.load_state_dict(state_dict, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
             # release memory
             del state_dict
         else:
             # We load the sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, resume_from_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
 
     def _load_best_model(self):
         logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
         best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        strict_load = is_sagemaker_mp_enabled()
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if os.path.exists(best_model_path):
             if self.deepspeed:
                 # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
@@ -1743,12 +1756,13 @@ def _load_best_model(self):
                 # We load the model state dict on the CPU to avoid an OOM error.
                 state_dict = torch.load(best_model_path, map_location="cpu")
                 # If the model is on the GPU, it still works!
-                load_result = self.model.load_state_dict(state_dict, strict=False)
-                self._issue_warnings_after_load(load_result)
+                load_result = model.load_state_dict(state_dict, strict=strict_load)
+                if not strict_load:
+                    self._issue_warnings_after_load(load_result)
         elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
-            # Best model is a sharded checkpoint
-            load_result = load_sharded_checkpoint(self.model, self.state.best_model_checkpoint, strict=False)
-            self._issue_warnings_after_load(load_result)
+            load_result = load_sharded_checkpoint(model, self.state.best_model_checkpoint, strict=strict_load)
+            if not strict_load:
+                self._issue_warnings_after_load(load_result)
         else:
             logger.warning(
                 f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1886,17 +1900,21 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
-                opt_state_dict = self.optimizer.state_dict()
-                # Save it and the scheduler on the main process
-                if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
-                    with warnings.catch_warnings(record=True) as caught_warnings:
-                        torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
-                    reissue_pt_warnings(caught_warnings)
-                    if self.do_grad_scaling:
-                        torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            smp.barrier()
+            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+            if self.args.should_save:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+                if self.do_grad_scaling:
+                    torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.deepspeed:
             # deepspeed.save_checkpoint above saves model/optim/sched
             torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
@@ -1945,6 +1963,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
         # not yet exist.
         os.makedirs(output_dir, exist_ok=True)
+
         local_rank = xm.get_local_ordinal() if is_torch_tpu_available() else self.args.local_rank
         if local_rank == -1:
             torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
@@ -1967,9 +1986,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
 
-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1985,9 +2007,16 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
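
Reviewer note (not part of the patch): with SageMaker Model Parallel enabled, smp.save(..., partial=True) writes one optimizer file per rank rather than a single optimizer.pt, which is why _load_optimizer_and_scheduler now globs for OPTIMIZER_NAME + "_*" instead of calling os.path.isfile. Below is a minimal standalone sketch of that existence check, assuming the usual OPTIMIZER_NAME = "optimizer.pt" constant; the helper name is hypothetical and not code from this diff.

    import glob
    import os

    OPTIMIZER_NAME = "optimizer.pt"  # assumed to match the transformers constant


    def optimizer_checkpoint_exists(checkpoint_dir: str, smp_enabled: bool) -> bool:
        # Hypothetical helper mirroring the check in _load_optimizer_and_scheduler:
        # partial SMP checkpoints are written as several files whose names start with
        # the optimizer file name (e.g. "optimizer.pt_<rank suffix>"), so a glob is
        # needed instead of a single isfile() test.
        if smp_enabled:
            return bool(glob.glob(os.path.join(checkpoint_dir, OPTIMIZER_NAME) + "_*"))
        return os.path.isfile(os.path.join(checkpoint_dir, OPTIMIZER_NAME))


    if __name__ == "__main__":
        # Example: a regular (non-SMP) checkpoint directory.
        print(optimizer_checkpoint_exists("output/checkpoint-500", smp_enabled=False))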