From 1d798ba61e8ff22d3ab5a7700c34f1d876249ea7 Mon Sep 17 00:00:00 2001
From: Cavdar
Date: Mon, 11 Apr 2022 17:54:14 -0700
Subject: [PATCH 1/2] partial checkpoint support for SMP

---
 src/transformers/modeling_utils.py |  53 +++++++-------
 src/transformers/trainer.py        | 107 ++++++++++++++++++++++-------
 src/transformers/training_args.py  |  28 +++++++-
 3 files changed, 139 insertions(+), 49 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7f1b12386202..9bda8bc7504a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1361,32 +1361,39 @@ def save_pretrained(
                 if ignore_key in state_dict.keys():
                     del state_dict[ignore_key]
 
-        # Shard the model if it is too big.
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+        from .utils import is_sagemaker_mp_enabled
+        if is_sagemaker_mp_enabled():
+            # Do not shard checkpoints when sagemaker model parallel is enabled
+            output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
+            save_function(state_dict, output_model_file)
+            logger.info(f"Model weights saved in {output_model_file}")
+        else:
+            # Shard the model if it is too big.
+            shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
 
-        # Clean the folder from a previous save
-        for filename in os.listdir(save_directory):
-            full_filename = os.path.join(save_directory, filename)
-            if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
-                os.remove(full_filename)
+            # Clean the folder from a previous save
+            for filename in os.listdir(save_directory):
+                full_filename = os.path.join(save_directory, filename)
+                if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename):
+                    os.remove(full_filename)
 
-        # Save the model
-        for shard_file, shard in shards.items():
-            save_function(shard, os.path.join(save_directory, shard_file))
+            # Save the model
+            for shard_file, shard in shards.items():
+                save_function(shard, os.path.join(save_directory, shard_file))
 
-        if index is None:
-            logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
-        else:
-            save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
-            # Save the index as well
-            with open(save_index_file, "w", encoding="utf-8") as f:
-                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-                f.write(content)
-            logger.info(
-                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
-                f"index located at {save_index_file}."
-            )
+            if index is None:
+                logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+            else:
+                save_index_file = os.path.join(save_directory, WEIGHTS_INDEX_NAME)
+                # Save the index as well
+                with open(save_index_file, "w", encoding="utf-8") as f:
+                    content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                    f.write(content)
+                logger.info(
+                    f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                    f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+                    f"index located at {save_index_file}."
+                )
 
         if push_to_hub:
             url = self._push_to_hub(repo, commit_message=commit_message)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0c3ea207bb14..1ba11a0f2896 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -836,16 +836,20 @@ def create_optimizer(self):
         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
         Trainer's init through `optimizers`, or subclass and override this method in a subclass.
         """
+        if is_sagemaker_mp_enabled():
+            opt_model = self.model_wrapped
+        else:
+            opt_model = self.model
         if self.optimizer is None:
-            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
+            decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
             decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
-                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
+                    "params": [p for n, p in opt_model.named_parameters() if n in decay_parameters],
                     "weight_decay": self.args.weight_decay,
                 },
                 {
-                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
+                    "params": [p for n, p in opt_model.named_parameters() if n not in decay_parameters],
                     "weight_decay": 0.0,
                 },
             ]
@@ -1032,6 +1036,8 @@ def _wrap_model(self, model, training=True):
             # Wrapping the base model twice in a DistributedModel will raise an error.
             if isinstance(self.model_wrapped, smp.model.DistributedModel):
                 return self.model_wrapped
+            if self.args.smp_tensor_parallel_full_model:
+                smp.set_tensor_parallelism(model)
             return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
 
         # already initialized its own DDP and AMP
@@ -1165,8 +1171,16 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
+
         if resume_from_checkpoint is not None:
-            if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+            if self.args.smp_load_partial:
+                import glob
+                # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
+                checkpoint_file_exists = glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
+            else:
+                checkpoint_file_exists = os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
+
+            if not checkpoint_file_exists:
                 raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
 
             logger.info(f"Loading model from {resume_from_checkpoint}).")
@@ -1185,13 +1199,14 @@ def train(
                 # will be resumed in deepspeed_init
                 pass
             else:
-                # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
-                # If the model is on the GPU, it still works!
-                self._load_state_dict_in_model(state_dict)
+                if not is_sagemaker_mp_enabled():
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                    # If the model is on the GPU, it still works!
+                    self._load_state_dict_in_model(state_dict)
 
-                # release memory
-                del state_dict
+                    # release memory
+                    del state_dict
 
         # If model was re-initialized, put it on the right device and update self.model_wrapped
         if model_reloaded:
@@ -1272,6 +1287,10 @@ def train(
 
         model = self._wrap_model(self.model_wrapped)
 
+        if resume_from_checkpoint is not None and is_sagemaker_mp_enabled():
+            state_dict = smp.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=self.args.smp_load_partial)
+            model.load_state_dict(state_dict)
+
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
         if model is not self.model:
             self.model_wrapped = model
@@ -1534,13 +1553,20 @@ def train(
                 xm.rendezvous("load_best_model_at_end")
             elif args.local_rank != -1:
                 dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
 
             logger.info(
                 f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
             )
 
             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
-            if os.path.exists(best_model_path):
+            if self.args.smp_load_partial:
+                import glob
+                # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
+                smp_checkpoint_file_exists = glob.glob(best_model_path + "_*")
+
+            if os.path.exists(best_model_path) or smp_checkpoint_file_exists:
                 if self.deepspeed:
                     # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping
                     deepspeed_engine, optimizer, lr_scheduler = deepspeed_reinit(self)
@@ -1554,9 +1580,13 @@ def train(
                     )
                 else:
                     # We load the model state dict on the CPU to avoid an OOM error.
-                    state_dict = torch.load(best_model_path, map_location="cpu")
-                    # If the model is on the GPU, it still works!
-                    self._load_state_dict_in_model(state_dict)
+                    if is_sagemaker_mp_enabled():
+                        state_dict = smp.load(best_model_path, partial=self.args.smp_load_partial)
+                        model.load_state_dict(state_dict)
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+                        # If the model is on the GPU, it still works!
+                        self._load_state_dict_in_model(state_dict)
             else:
                 logger.warning(
                     f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
@@ -1714,12 +1744,15 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                 reissue_pt_warnings(caught_warnings)
         elif is_sagemaker_mp_enabled():
-            if smp.rdp_rank() == 0:
-                # Consolidate the state dict on all processed of rdp_rank 0
+            # Consolidate the state dict on all processes
+            if self.args.smp_save_partial:
+                opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            else:
                 opt_state_dict = self.optimizer.state_dict()
+            if self.args.should_save or smp.state.cfg.shard_optimizer_state:
+                smp.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME), partial=self.args.smp_save_partial, v3=smp.state.cfg.shard_optimizer_state)
                 # Save it and the scheduler on the main process
                 if self.args.should_save:
-                    torch.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME))
                     with warnings.catch_warnings(record=True) as caught_warnings:
                         torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
                     reissue_pt_warnings(caught_warnings)
@@ -1795,7 +1828,14 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
             return
 
-        if os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) and os.path.isfile(
+        if self.args.smp_load_partial:
+            import glob
+            # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
+            checkpoint_file_exists = glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+        else:
+            checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+
+        if checkpoint_file_exists and os.path.isfile(
             os.path.join(checkpoint, SCHEDULER_NAME)
         ):
             # Load in optimizer and scheduler states
@@ -1813,9 +1853,14 @@ def _load_optimizer_and_scheduler(self, checkpoint):
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
-                self.optimizer.load_state_dict(
-                    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
-                )
+                if is_sagemaker_mp_enabled():
+                    def opt_load_hook(mod, opt):
+                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=self.args.smp_load_partial))
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    self.optimizer.load_state_dict(
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                    )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
                 reissue_pt_warnings(caught_warnings)
@@ -2087,7 +2132,10 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             self._save_tpu(output_dir)
         elif is_sagemaker_mp_enabled():
             # Calling the state_dict needs to be done on the wrapped model and on all processes.
-            state_dict = self.model_wrapped.state_dict()
+            if self.args.smp_save_partial:
+                state_dict = self.model_wrapped.local_state_dict()
+            else:
+                state_dict = self.model_wrapped.state_dict()
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         elif (
@@ -2175,9 +2223,15 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                 if state_dict is None:
                     state_dict = self.model.state_dict()
-                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+                if self.args.smp_save_partial:
+                    smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=self.args.smp_save_partial)
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir, state_dict=state_dict)
+            if self.args.smp_save_partial:
+                self.model.save_pretrained(output_dir, save_function=smp.save, state_dict=state_dict)
+            else:
+                self.model.save_pretrained(output_dir, state_dict=state_dict)
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
 
@@ -2592,7 +2646,10 @@ def _nested_gather(self, tensors, name=None):
                 name = "nested_gather"
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
-            tensors = smp_gather(tensors)
+            if smp.state.cfg.ddp:
+                tensors = distributed_concat(tensors.cuda())
+            else:
+                tensors = smp_gather(tensors)
         elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)
         return tensors
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index adaf50da8ffe..1da2c5970f3c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -423,6 +423,12 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        smp_save_partial (`bool`, *optional*, defaults to `False`):
+            If True, saves checkpoints partially. `"smp_save_partial"` can only be used with Sagemaker Model Parallel library.
+        smp_load_partial (`bool`, *optional*, defaults to `False`):
+            If True, loads partial checkpoints. `"smp_load_partial"` can only be used with Sagemaker Model Parallel library.
+        smp_tensor_parallel_full_model (`bool`, *optional*, defaults to `False`):
+            If True, apply tensor parallelism for the full model. `"smp_tensor_parallel_full_model"` can only be used with Sagemaker Model Parallel library.
     """
 
     output_dir: str = field(
@@ -750,6 +756,10 @@ class TrainingArguments:
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
+    smp_save_partial: bool = field(default=False, metadata={"help": "Save checkpoints partially for SMP."})
+    smp_load_partial: bool = field(default=False, metadata={"help": "Load partial checkpoints for SMP."})
+    smp_tensor_parallel_full_model: bool = field(default=False, metadata={"help": "Enables tensor parallelism for full model in SMP."})
+
     # Deprecated arguments
     fp16_backend: str = field(
         default="auto",
@@ -985,6 +995,19 @@ def __post_init__(self):
                 FutureWarning,
             )
 
+        if (self.smp_save_partial or self.smp_load_partial) and not is_sagemaker_mp_enabled():
+            raise ValueError(f"smp_save_partial and smp_load_partial can only be used with Sagemaker Model Parallel library.")
+
+        if (is_sagemaker_mp_enabled() and not self.smp_save_partial and smp.state.cfg.shard_optimizer_state ):
+            warnings.warn("Optimizer sharding can only be used with partial checkpointing. Switching to partial checkpointing.")
+            self.smp_save_partial = True
+
+        if (is_sagemaker_mp_enabled() and self.smp_save_partial and self.load_best_model_at_end ):
+            self.smp_load_partial = True
+
+        if (is_sagemaker_mp_enabled() and (not self.smp_save_partial) and (self.save_strategy != IntervalStrategy.NO)):
+            warnings.warn("Saving weights but not the optimizer state.")
+
     def __str__(self):
         self_as_dict = asdict(self)
 
@@ -1218,7 +1241,10 @@ def should_save(self):
             return self.local_process_index == 0
         else:
             if is_sagemaker_mp_enabled():
-                return smp.rank() == 0
+                if self.smp_save_partial:
+                    return smp.rdp_rank() == 0
+                else:
+                    return smp.rank() == 0
             else:
                 return self.process_index == 0
 

From 8ebcbd3d4556ca92ed54f852e66142aa9fae840f Mon Sep 17 00:00:00 2001
From: Cavdar
Date: Wed, 13 Apr 2022 16:34:31 -0700
Subject: [PATCH 2/2] make style changes

---
 src/transformers/modeling_utils.py |  1 +
 src/transformers/trainer.py        | 25 ++++++++++++++++++-------
 src/transformers/training_args.py  | 27 ++++++++++++++++++---------
 3 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9bda8bc7504a..6c76f83dd292 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1362,6 +1362,7 @@ def save_pretrained(
                     del state_dict[ignore_key]
 
         from .utils import is_sagemaker_mp_enabled
+
         if is_sagemaker_mp_enabled():
             # Do not shard checkpoints when sagemaker model parallel is enabled
             output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1ba11a0f2896..1dd622310475 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1171,10 +1171,10 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-
         if resume_from_checkpoint is not None:
             if self.args.smp_load_partial:
                 import glob
+
                 # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
                 checkpoint_file_exists = glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
             else:
@@ -1288,7 +1288,9 @@ def train(
         model = self._wrap_model(self.model_wrapped)
 
         if resume_from_checkpoint is not None and is_sagemaker_mp_enabled():
-            state_dict = smp.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=self.args.smp_load_partial)
+            state_dict = smp.load(
+                os.path.join(resume_from_checkpoint, WEIGHTS_NAME), partial=self.args.smp_load_partial
+            )
             model.load_state_dict(state_dict)
 
         # for the rest of this function `model` is the outside model, whether it was wrapped or not
@@ -1563,6 +1565,7 @@ def train(
             best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
             if self.args.smp_load_partial:
                 import glob
+
                 # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
                 smp_checkpoint_file_exists = glob.glob(best_model_path + "_*")
 
@@ -1750,7 +1753,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
             else:
                 opt_state_dict = self.optimizer.state_dict()
             if self.args.should_save or smp.state.cfg.shard_optimizer_state:
-                smp.save(opt_state_dict, os.path.join(output_dir, OPTIMIZER_NAME), partial=self.args.smp_save_partial, v3=smp.state.cfg.shard_optimizer_state)
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=self.args.smp_save_partial,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
                 # Save it and the scheduler on the main process
                 if self.args.should_save:
                     with warnings.catch_warnings(record=True) as caught_warnings:
@@ -1830,14 +1838,13 @@ def _load_optimizer_and_scheduler(self, checkpoint):
 
         if self.args.smp_load_partial:
             import glob
+
            # SMP partial checkpoints are in {filename}_{pp_rank()}_{tp_rank()} or {filename}_{pp_rank()}_{tp_rank()}_{rdp_rank()} format.
             checkpoint_file_exists = glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
         else:
             checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
 
-        if checkpoint_file_exists and os.path.isfile(
-            os.path.join(checkpoint, SCHEDULER_NAME)
-        ):
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
             # Load in optimizer and scheduler states
             if is_torch_tpu_available():
                 # On TPU we have to take some extra precautions to properly load the states on the right device.
@@ -1854,8 +1861,12 @@ def _load_optimizer_and_scheduler(self, checkpoint):
             else:
                 map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                 if is_sagemaker_mp_enabled():
+
                     def opt_load_hook(mod, opt):
-                        opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=self.args.smp_load_partial))
+                        opt.load_state_dict(
+                            smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=self.args.smp_load_partial)
+                        )
+
                     self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
                     self.optimizer.load_state_dict(
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 1da2c5970f3c..eaf69b99c200 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -424,11 +424,14 @@ class TrainingArguments:
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
         smp_save_partial (`bool`, *optional*, defaults to `False`):
-            If True, saves checkpoints partially. `"smp_save_partial"` can only be used with Sagemaker Model Parallel library.
+            If True, saves checkpoints partially. `"smp_save_partial"` can only be used with Sagemaker Model Parallel
+            library.
         smp_load_partial (`bool`, *optional*, defaults to `False`):
-            If True, loads partial checkpoints. `"smp_load_partial"` can only be used with Sagemaker Model Parallel library.
+            If True, loads partial checkpoints. `"smp_load_partial"` can only be used with Sagemaker Model Parallel
+            library.
         smp_tensor_parallel_full_model (`bool`, *optional*, defaults to `False`):
-            If True, apply tensor parallelism for the full model. `"smp_tensor_parallel_full_model"` can only be used with Sagemaker Model Parallel library.
+            If True, apply tensor parallelism for the full model. `"smp_tensor_parallel_full_model"` can only be used
+            with Sagemaker Model Parallel library.
     """
 
     output_dir: str = field(
@@ -758,7 +761,9 @@ class TrainingArguments:
     )
     smp_save_partial: bool = field(default=False, metadata={"help": "Save checkpoints partially for SMP."})
     smp_load_partial: bool = field(default=False, metadata={"help": "Load partial checkpoints for SMP."})
-    smp_tensor_parallel_full_model: bool = field(default=False, metadata={"help": "Enables tensor parallelism for full model in SMP."})
+    smp_tensor_parallel_full_model: bool = field(
+        default=False, metadata={"help": "Enables tensor parallelism for full model in SMP."}
+    )
 
     # Deprecated arguments
     fp16_backend: str = field(
         default="auto",
@@ -996,16 +1001,20 @@ def __post_init__(self):
             )
 
         if (self.smp_save_partial or self.smp_load_partial) and not is_sagemaker_mp_enabled():
-            raise ValueError(f"smp_save_partial and smp_load_partial can only be used with Sagemaker Model Parallel library.")
+            raise ValueError(
+                f"smp_save_partial and smp_load_partial can only be used with Sagemaker Model Parallel library."
+            )
 
-        if (is_sagemaker_mp_enabled() and not self.smp_save_partial and smp.state.cfg.shard_optimizer_state ):
-            warnings.warn("Optimizer sharding can only be used with partial checkpointing. Switching to partial checkpointing.")
+        if is_sagemaker_mp_enabled() and not self.smp_save_partial and smp.state.cfg.shard_optimizer_state:
+            warnings.warn(
+                "Optimizer sharding can only be used with partial checkpointing. Switching to partial checkpointing."
+            )
             self.smp_save_partial = True
 
-        if (is_sagemaker_mp_enabled() and self.smp_save_partial and self.load_best_model_at_end ):
+        if is_sagemaker_mp_enabled() and self.smp_save_partial and self.load_best_model_at_end:
             self.smp_load_partial = True
 
-        if (is_sagemaker_mp_enabled() and (not self.smp_save_partial) and (self.save_strategy != IntervalStrategy.NO)):
+        if is_sagemaker_mp_enabled() and (not self.smp_save_partial) and (self.save_strategy != IntervalStrategy.NO):
             warnings.warn("Saving weights but not the optimizer state.")
 
     def __str__(self):