[Trainer] fix save_model #9287

Merged
paddlenlp/trainer/trainer.py: 18 changes (10 additions, 8 deletions)
@@ -2289,7 +2289,6 @@ def save_model(
         self,
         output_dir: Optional[str] = None,
         merge_tensor_parallel: Optional[bool] = False,
-        signal_dir: Optional[str] = None,
     ):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2300,14 +2299,16 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir
 
-        if signal_dir is None:
-            if PREFIX_CHECKPOINT_DIR in output_dir:
-                signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
-            else:
-                signal_dir = self.args.output_signal_dir
-
         if ShardingOption.FULL_SHARD in self.args.sharding:
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
-            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir)
+            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(signal_dir, exist_ok=True)
@@ -2367,11 +2368,11 @@ def _save_checkpoint(self, model, metrics=None):
         signal_dir = os.path.join(run_signal_dir, checkpoint_folder)
 
         if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
         elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
-            self.save_model(output_dir, True, signal_dir)
+            self.save_model(output_dir, True)
         else:
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
 
         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
@@ -2588,15 +2589,16 @@ def _save(
         output_dir: Optional[str] = None,
         state_dict=None,
         merge_tensor_parallel=False,
-        signal_dir: Optional[str] = None,
     ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
 
         # signal_dir is used for asynchronous saving situations.
+        signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir
+            if PREFIX_CHECKPOINT_DIR in output_dir:
+                signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
