diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 06e5a9e91d69..2b8cb013f6c3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2850,7 +2850,8 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                     " zero_to_fp32.py to recover weights"
                 )
-                self._save(output_dir, state_dict={})
+                if self.args.should_save:
+                    self._save(output_dir, state_dict={})
                 # remove the dummy state_dict
                 remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model_wrapped.save_checkpoint(output_dir)
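
A minimal, self-contained sketch of the pattern this hunk applies (hypothetical helper and file names, not the Trainer implementation): per-file writes are gated behind `should_save`, which is true only on the rank responsible for writing to disk, while the collective engine-level checkpoint call still runs on every rank. Without the gate, several ranks could race to write (and later delete) the same dummy weight files in `output_dir`.

# Hedged sketch, assuming a distributed run where exactly one rank has should_save=True.
import json
import os
from typing import Callable


def save_with_engine(output_dir: str, should_save: bool, engine_save: Callable[[str], None]) -> None:
    """Write rank-local files only when `should_save` is True, then join the collective save."""
    if should_save:
        os.makedirs(output_dir, exist_ok=True)
        # stand-in for `self._save(output_dir, state_dict={})` in the diff:
        # write a placeholder file that a later cleanup step removes
        with open(os.path.join(output_dir, "dummy_state.json"), "w") as f:
            json.dump({}, f)
    # every rank must reach this call, because engine checkpointing is collective
    engine_save(output_dir)


if __name__ == "__main__":
    # single-process demo: pretend we are the saving rank and the "engine" is a no-op
    save_with_engine("/tmp/ckpt_demo", should_save=True, engine_save=lambda d: None)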