diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index eee4ddc5ce53..307adc7b60ec 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -364,7 +364,17 @@ def __init__(
             )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
 
-        self._save_ckpt_func = dist.save_state_dict if self.args.enable_auto_parallel else paddle.save
+        def _save_ckpt_func(state_dict, path, signal_path=None):
+            if self.args.enable_auto_parallel:
+                dist.save_state_dict(state_dict, path)
+            else:
+                paddle.save(state_dict, path)
+
+            if signal_path is not None:
+                with open(signal_path, mode="w+") as f:
+                    f.write("1")
+
+        self._save_ckpt_func = _save_ckpt_func
         self._load_ckpt_func = dist.load_state_dict if self.args.enable_auto_parallel else paddle.load
         if self.args.use_async_save:
             self._async_optimizer_saver = AsyncSaver()
@@ -2310,7 +2320,9 @@ def _save_checkpoint(self, model, metrics=None):
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(),
                             os.path.join(output_dir, optimizer_name),
+                            saved_signal_path,
                         )
+
                     else:
                         state_dict = self.optimizer.state_dict()
                         save_path = os.path.join(output_dir, optimizer_name)
@@ -2320,9 +2332,7 @@ def _save_checkpoint(self, model, metrics=None):
                                 state_dict, save_path, saved_signal_path=saved_signal_path
                             )
                         else:
-                            self._save_ckpt_func(state_dict, save_path)
-                            with open(saved_signal_path, mode="w+") as f:
-                                f.write("1")
+                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2341,10 +2351,16 @@ def _save_checkpoint(self, model, metrics=None):
                 else:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
-                            self._filter_moe_no_sync_optimizer_params(), os.path.join(output_dir, OPTIMIZER_NAME)
+                            self._filter_moe_no_sync_optimizer_params(),
+                            os.path.join(output_dir, optimizer_name),
+                            saved_signal_path,
                         )
                     else:
-                        self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+                        self._save_ckpt_func(
+                            self.optimizer.state_dict(),
+                            os.path.join(output_dir, optimizer_name),
+                            saved_signal_path,
+                        )
 
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index abb34cd1f1e6..be4ab9477ce2 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -83,6 +83,9 @@ def filter_sharded_params(state_dict, optimizer, sharding_group):
         for (k, v) in state_dict.items():
             if v.name in filtered_parameters:
                 filtered_state_dict[k] = v
+            else:
+                if sharding_rank == 0:
+                    filtered_state_dict[k] = v
     return filtered_state_dict
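Illustrative sketch (not part of the patch): the change above replaces the direct paddle.save / dist.save_state_dict assignment with a wrapper that also writes a small "saved signal" file once the checkpoint is on disk, so downstream logic can detect a completed write. The snippet below is a minimal, self-contained Python sketch of that pattern only; make_save_fn, use_dist_save, and the toy serialization are hypothetical stand-ins, not PaddleNLP APIs.

import os
import tempfile


def make_save_fn(use_dist_save=False):
    """Build a save function mirroring the wrapper pattern in the patch (sketch only)."""

    def save_fn(state_dict, path, signal_path=None):
        if use_dist_save:
            # Placeholder for a distributed save (dist.save_state_dict in the patch).
            raise NotImplementedError("distributed save is out of scope for this sketch")
        # Stand-in for paddle.save: write the dict as plain key/value text.
        with open(path, "w") as f:
            for k, v in state_dict.items():
                f.write(f"{k}\t{v}\n")
        # Only after the checkpoint hits disk, drop the signal file so a
        # watcher can tell that the write has finished.
        if signal_path is not None:
            with open(signal_path, mode="w+") as f:
                f.write("1")

    return save_fn


if __name__ == "__main__":
    save_fn = make_save_fn()
    out_dir = tempfile.mkdtemp()
    ckpt_path = os.path.join(out_dir, "optimizer.pdopt")
    signal_path = os.path.join(out_dir, "saved_signal_0")
    save_fn({"step": 42, "lr": 0.001}, ckpt_path, signal_path)
    print(os.path.exists(ckpt_path), os.path.exists(signal_path))  # True True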