【auto_parallel】Add checkpoint convertor #8847

Merged
31 commits merged on Aug 22, 2024
42 changes: 27 additions & 15 deletions paddlenlp/trainer/auto_trainer.py
@@ -39,6 +39,7 @@
     has_length,
     speed_metrics,
 )
+from .utils.ckpt_converter import CheckpointConverter
 from .utils.helper import distributed_file, distributed_isfile  # nested_truncate,

 try:
@@ -695,20 +696,16 @@
                 )
             )

-        ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
-
-        if not os.path.isdir(ckpt_path):
-            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
-
         if self.args.to_static:
-            opt_state_dict = {
+            model_state_dict = {
                 key: value
-                for key, value in self.model_wrapped.state_dict("opt").items()
+                for key, value in self.model_wrapped.state_dict("param").items()
                 if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
-            state_dict = {
-                MODEL_NAME: self.model_wrapped.state_dict("param"),
-                OPTIMIZER_NAME: opt_state_dict,
+            optim_state_dict = {
+                key: value
+                for key, value in self.model_wrapped.state_dict("opt").items()
+                if not any(keyword in key for keyword in FREE_SVAE_LOAD_KEY_PATTERNS)
             }
         else:
             model_state_dict = self.model_wrapped.state_dict()
@@ -721,12 +718,27 @@
             optim_state_dict = self.optimizer.state_dict()
             optim_state_dict.pop("LR_Scheduler", None)

-            state_dict = {
-                MODEL_NAME: model_state_dict,
-                OPTIMIZER_NAME: optim_state_dict,
-            }
+        state_dict = {
+            MODEL_NAME: model_state_dict,
+            OPTIMIZER_NAME: optim_state_dict,
+        }

-        self._load_ckpt_func(state_dict, ckpt_path)
+        parameter_to_structured_name = {}
+        if self.args.to_static:
+            parameter_to_structured_name = self.model_wrapped._parameter_to_structured_name
+        else:
+            for state_name, state_value in self.model_wrapped.state_dict().items():
+                parameter_to_structured_name[state_value.name] = state_name
+        if self.args.auto_parallel_resume_form_hybrid_parallel:
+            CheckpointConverter(
+                resume_from_checkpoint, state_dict, parameter_to_structured_name
+            ).load_from_hybrid_parallel_checkpoint()
+        else:
+            ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
+            if not os.path.isdir(ckpt_path):
+                raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+            self._load_ckpt_func(state_dict, ckpt_path)

         # release memory
         del state_dict
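
As a reading aid (not part of the diff), here is a minimal sketch of what the dynamic-mode branch above builds: a map from each tensor's internal Paddle name to its structured `state_dict` key, which `CheckpointConverter` then uses to match shards from the hybrid-parallel checkpoint. The toy `Sequential` model is illustrative only.

```python
import paddle

# Toy stand-in for self.model_wrapped in dynamic (non-to_static) mode.
model = paddle.nn.Sequential(
    paddle.nn.Linear(4, 4),
    paddle.nn.Linear(4, 2),
)

# Mirrors the new else branch: map each tensor's internal Paddle name
# (e.g. "linear_0.w_0") to its structured state_dict key (e.g. "0.weight").
parameter_to_structured_name = {}
for state_name, state_value in model.state_dict().items():
    parameter_to_structured_name[state_value.name] = state_name

print(parameter_to_structured_name)
# e.g. {'linear_0.w_0': '0.weight', 'linear_0.b_0': '0.bias', 'linear_1.w_0': '1.weight', ...}
```

In `to_static` mode the same mapping is taken directly from `self.model_wrapped._parameter_to_structured_name`, so no loop is needed.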
6 changes: 6 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -349,6 +349,8 @@ class TrainingArguments:
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
             scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details.
+        auto_parallel_resume_form_hybrid_parallel (`bool`, *optional*):
+            Whether hybrid parallel checkpoints should be loaded in auto parallel mode.
         flatten_param_grads (`bool`, *optional*):
             Whether use flatten_param_grads method in optimizer, only used on NPU devices. Default is `False`.
         skip_profile_timer (`bool`, *optional*):
@@ -770,6 +772,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
+    auto_parallel_resume_form_hybrid_parallel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether hybrid parallel checkpoints should be loaded in auto parallel mode."},
+    )
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
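
A hedged sketch of how the new switch might be exercised when resuming from a hybrid-parallel checkpoint. The paths are placeholders, and importing `TrainingArguments` from the top-level `paddlenlp.trainer` package is an assumption about the package layout, not something this diff shows.

```python
from paddlenlp.trainer import TrainingArguments  # assumed public import path

# Resume an auto-parallel run from a checkpoint written by a hybrid-parallel job.
# With the flag on, AutoTrainer routes loading through CheckpointConverter instead
# of the plain distributed checkpoint under DIST_CKPT_PATH.
args = TrainingArguments(
    output_dir="./output",                               # placeholder
    resume_from_checkpoint="./checkpoints/hybrid_ckpt",  # placeholder path
    auto_parallel_resume_form_hybrid_parallel=True,      # note: "form" is part of the actual field name
)
```

Since the option is a regular dataclass field, the usual argument parsing should also expose it as a `--auto_parallel_resume_form_hybrid_parallel` command-line flag.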