[auto_parallel] Add checkpoint convertor (PaddlePaddle#8847)
* Add the checkpoint conversion module
xingmingyyj authored and Mangodadada committed Sep 10, 2024
1 parent 78c2863 commit 4fedd09
Showing 3 changed files with 1,160 additions and 15 deletions.
42 changes: 27 additions & 15 deletions paddlenlp/trainer/auto_trainer.py
@@ -40,6 +40,7 @@
     has_length,
     speed_metrics,
 )
+from .utils.ckpt_converter import CheckpointConverter
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,
 
 try:
@@ -720,20 +721,16 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
                 )
             )
 
-            ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
-
-            if not os.path.isdir(ckpt_path):
-                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
             if self.args.to_static:
-                opt_state_dict = {
+                model_state_dict = {
                     key: value
-                    for key, value in self.model_wrapped.state_dict("opt").items()
+                    for key, value in self.model_wrapped.state_dict("param").items()
                     if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
                 }
-                state_dict = {
-                    MODEL_NAME: self.model_wrapped.state_dict("param"),
-                    OPTIMIZER_NAME: opt_state_dict,
+                optim_state_dict = {
+                    key: value
+                    for key, value in self.model_wrapped.state_dict("opt").items()
+                    if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
                 }
             else:
                 model_state_dict = self.model_wrapped.state_dict()
@@ -746,12 +743,27 @@
                 optim_state_dict = self.optimizer.state_dict()
                 optim_state_dict.pop("LR_Scheduler", None)
 
-                state_dict = {
-                    MODEL_NAME: model_state_dict,
-                    OPTIMIZER_NAME: optim_state_dict,
-                }
+            state_dict = {
+                MODEL_NAME: model_state_dict,
+                OPTIMIZER_NAME: optim_state_dict,
+            }
 
-            self._load_ckpt_func(state_dict, ckpt_path)
+            parameter_to_structured_name = {}
+            if self.args.to_static:
+                parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name
+            else:
+                for state_name, state_value in self.model_wrapped.state_dict().items():
+                    parameter_to_structured_name[state_value.name] = state_name
+
+            if self.args.auto_parallel_resume_form_hybrid_parallel:
+                CheckpointConverter(
+                    resume_from_checkpoint, state_dict, parameter_to_structured_name
+                ).load_from_hybrid_parallel_checkpoint()
+            else:
+                ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
+                if not os.path.isdir(ckpt_path):
+                    raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+                self._load_ckpt_func(state_dict, ckpt_path)
 
             # release memory
             del state_dict
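
For reference, a minimal sketch of driving the new converter outside the Trainer, mirroring the dynamic-mode branch above. It assumes only the constructor and method visible in this diff (CheckpointConverter(resume_from_checkpoint, state_dict, parameter_to_structured_name).load_from_hybrid_parallel_checkpoint()); the toy model, optimizer, checkpoint path, and the MODEL_NAME/OPTIMIZER_NAME values are illustrative, not taken from this commit.

# Hedged sketch: convert and load a hybrid-parallel checkpoint into an
# auto-parallel-style state_dict, mirroring the trainer logic above.
import paddle
from paddlenlp.trainer.utils.ckpt_converter import CheckpointConverter

MODEL_NAME = "model"          # assumed keys; the trainer's actual constants may differ
OPTIMIZER_NAME = "optimizer"

model = paddle.nn.Linear(8, 8)  # toy stand-in for the real network
optimizer = paddle.optimizer.AdamW(parameters=model.parameters())

model_state_dict = model.state_dict()
optim_state_dict = optimizer.state_dict()
optim_state_dict.pop("LR_Scheduler", None)

state_dict = {MODEL_NAME: model_state_dict, OPTIMIZER_NAME: optim_state_dict}

# Map each parameter's internal name (e.g. "linear_0.w_0") to its structured
# state_dict key, exactly as the dynamic-mode branch above does.
parameter_to_structured_name = {
    value.name: key for key, value in model.state_dict().items()
}

resume_from_checkpoint = "./hybrid_parallel_ckpt"  # hypothetical path
CheckpointConverter(
    resume_from_checkpoint, state_dict, parameter_to_structured_name
).load_from_hybrid_parallel_checkpoint()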
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -353,6 +353,8 @@ class TrainingArguments:
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
             scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details.
+        auto_parallel_resume_form_hybrid_parallel (`bool`, *optional*):
+            Whether hybrid parallel checkpoints can be loaded in auto parallel mode.
         flatten_param_grads (`bool`, *optional*):
             Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`.
         skip_profile_timer (`bool`, *optional*):
@@ -783,6 +785,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether hybrid parallel checkpoints can be loaded in auto parallel mode."},
+    )
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
1,127 changes: 1,127 additions & 0 deletions paddlenlp/trainer/utils/ckpt_converter.py (new file; large diff not rendered on this page)
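
A hedged example of opting into the new resume path from a training script: setting the new flag on TrainingArguments makes _load_from_checkpoint route through CheckpointConverter instead of requiring a dist_ckpt directory. The field name and its default come from the training_args.py diff above; the paths, output_dir, and trainer wiring are illustrative only.

# Hedged sketch: resuming an auto-parallel run from a checkpoint saved by a
# hybrid-parallel (dygraph) job.
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",                            # illustrative
    resume_from_checkpoint="./hybrid_parallel_ckpt",  # hypothetical path
    # New in this commit (default False): convert the hybrid-parallel checkpoint
    # on the fly instead of expecting an auto-parallel dist_ckpt directory.
    auto_parallel_resume_form_hybrid_parallel=True,
)

# ...build the model/optimizer and the auto-parallel trainer as usual, then:
# trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)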
