[Trainer] fix save_model (#9287)
* [Unified Checkpoint] Support expert parallel (#9055)

* update code

* [Unified Checkpoint] Fix generation config save (#9223)

* [Unified Checkpoint] update async_save_info in develop (#9173)

* [Unified Checkpoint] update async save logic (#9274)

* update async save signal

* fix async save hang

* bug fix

* bug fix

* [Trainer] fix save_model (#9286)

* bug fix

* bug fix

---------

Co-authored-by: Weiguo Zhu <[email protected]>
DesmonDay and DrownFish19 authored Oct 17, 2024
1 parent 036e1cd commit 8da4248
Showing 1 changed file with 10 additions and 8 deletions.
paddlenlp/trainer/trainer.py: 10 additions & 8 deletions
@@ -2289,7 +2289,6 @@ def save_model(
         self,
         output_dir: Optional[str] = None,
         merge_tensor_parallel: Optional[bool] = False,
-        signal_dir: Optional[str] = None,
     ):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2300,14 +2299,16 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir

-        if signal_dir is None:
+        if PREFIX_CHECKPOINT_DIR in output_dir:
+            signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
+        else:
             signal_dir = self.args.output_signal_dir

         if ShardingOption.FULL_SHARD in self.args.sharding:
             self.model_wrapped.get_all_parameters(convert2cpu=True)

         if self.args.should_save_model_state:
-            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir)
+            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(signal_dir, exist_ok=True)
@@ -2367,11 +2368,11 @@ def _save_checkpoint(self, model, metrics=None):
         signal_dir = os.path.join(run_signal_dir, checkpoint_folder)

         if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
         elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
-            self.save_model(output_dir, True, signal_dir)
+            self.save_model(output_dir, True)
         else:
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)

         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
@@ -2588,15 +2589,16 @@ def _save(
         output_dir: Optional[str] = None,
         state_dict=None,
         merge_tensor_parallel=False,
-        signal_dir: Optional[str] = None,
     ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
+
+        # signal_dir is used for asynchronous saving situations.
+        signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir
             if PREFIX_CHECKPOINT_DIR in output_dir:
                 signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")

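In short, the commit removes the signal_dir parameter from save_model and _save; callers now pass only output_dir (and optionally merge_tensor_parallel), and both methods derive the asynchronous-save signal directory themselves from self.args.output_signal_dir and the output_dir being written. Below is a minimal standalone sketch of that derivation, assuming a hypothetical helper derive_signal_dir and a placeholder value for PREFIX_CHECKPOINT_DIR; neither is code taken from the repository.

import os

# Assumed stand-in for the PREFIX_CHECKPOINT_DIR constant used in the diff.
PREFIX_CHECKPOINT_DIR = "checkpoint"

def derive_signal_dir(output_dir: str, output_signal_dir: str) -> str:
    """Mirror the logic the patched save_model/_save use to pick the directory
    for async-save finish signals, based only on the model output directory."""
    if PREFIX_CHECKPOINT_DIR in output_dir:
        # Intermediate checkpoint: signals go to a per-checkpoint subfolder
        # under the signal root (e.g. <signal_root>/checkpoint-500).
        return os.path.join(output_signal_dir, os.path.split(output_dir)[-1])
    # Final save without a checkpoint prefix: signals go to the signal root.
    return output_signal_dir

if __name__ == "__main__":
    print(derive_signal_dir("output/checkpoint-500", "output_signal"))  # output_signal/checkpoint-500
    print(derive_signal_dir("output", "output_signal"))                 # output_signal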