Refine checkpoint converter (#9001)

* refine
zhangbo9674 authored Aug 27, 2024
1 parent 48820cb commit a0609e8
Showing 3 changed files with 187 additions and 14 deletions.
4 changes: 3 additions & 1 deletion paddlenlp/trainer/auto_trainer.py
@@ -722,6 +722,8 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):
)

if self.args.to_static:
if self.model_wrapped._mode is None:
self.model_wrapped.train()
model_state_dict = {
key: value
for key, value in self.model_wrapped.state_dict("param").items()
@@ -757,7 +759,7 @@ def _load_from_checkpoint(self, resume_from_checkpoint=None):

if self.args.auto_parallel_resume_form_hybrid_parallel:
CheckpointConverter(
resume_from_checkpoint, state_dict, parameter_to_structured_name
resume_from_checkpoint, state_dict, parameter_to_structured_name, self.args
).load_from_hybrid_parallel_checkpoint()
else:
ckpt_path = os.path.join(resume_from_checkpoint, DIST_CKPT_PATH)
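
The change above wires the trainer's TrainingArguments into CheckpointConverter so the converter can skip optimizer state when ignore_load_lr_and_optim is set, and it puts the wrapped model into train mode before reading its parameter state dict under --to_static. A minimal sketch of the optimizer-skipping behaviour, using stand-in objects rather than the real TrainingArguments and trainer state dict, might look like this:

from types import SimpleNamespace

# Stand-in for the TrainingArguments object now passed as the converter's fourth argument.
training_args = SimpleNamespace(ignore_load_lr_and_optim=True)

# Hypothetical trainer state dict with a model entry and an optimizer entry.
state_dict = {
    "model": {"embed_tokens.weight": "tensor placeholder"},
    "optimizer": {"embed_tokens.weight.moment1": "tensor placeholder"},
}

# Mirrors the new guard in CheckpointConverter.__init__: when lr/optimizer loading is
# ignored, the optimizer sub-dict is dropped before the state dict is flattened.
if training_args.ignore_load_lr_and_optim:
    state_dict.pop("optimizer")

# The converter would then be built roughly as:
# CheckpointConverter(ckpt_path, state_dict, parameter_to_structured_name, training_args)
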
37 changes: 24 additions & 13 deletions paddlenlp/trainer/utils/ckpt_converter.py
@@ -33,15 +33,22 @@
MODEL_WEIGHT_SUFFIX = ".pdparams"
OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
SCHEDULER_NAME = "scheduler.pdparams"
SCALAR_NAME = "scalar.pdparams"
MODEL_META_FILE_NAME = "model_meta.json"
OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]
MODEL_STATE_FILE_MIN_SIZE = 512


class CheckpointConverter:
def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, patch_dict=None):
def __init__(
self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structured_name, trainging_args=None, patch_dict=None
):
self.use_dist = True if paddle.distributed.get_world_size() > 1 else False
self.path = hybrid_parallel_ckpt_path

if trainging_args.ignore_load_lr_and_optim:
state_dict.pop("optimizer")

self.auto_parallel_state_dict = self.flatten_state_dict(state_dict)
self.parameter_to_structured_name = self.gather_global_object(parameter_to_structured_name)
model_state_global_shape = {}
@@ -74,9 +81,9 @@ def __init__(self, hybrid_parallel_ckpt_path, state_dict, parameter_to_structure
for k, v in self.auto_parallel_state_dict.items():
if k in self.patch_dict:
del_keys.append(k)

for k in del_keys:
self.auto_parallel_state_dict[self.patch_dict[k]] = self.auto_parallel_state_dict[k]
for k in del_keys:
self.auto_parallel_state_dict.pop(k)

flags = [
@@ -896,25 +903,26 @@ def rename(old_name, parameter_to_structured_name):
return renamed_state_dict

def rename_using_optimizer_state_order(self, model_state_keys, optimizer_state_dict):

name_mapping = {}
suffix_bucket = {}
assert len(optimizer_state_dict) % len(model_state_keys) == 0
for suffix in OPTIMIZER_STATE_NAME_SUFFIX:
suffix_bucket[suffix] = []
for satte_name, satte_value in optimizer_state_dict.items():
if "moment1" in satte_name:
suffix_bucket[".moment1"].append(satte_name)
elif "moment2" in satte_name:
suffix_bucket[".moment2"].append(satte_name)
elif "beta1_pow_acc" in satte_name:
suffix_bucket[".beta1_pow_acc"].append(satte_name)
elif "beta2_pow_acc" in satte_name:
suffix_bucket[".beta2_pow_acc"].append(satte_name)
for opt_name, opt_value in optimizer_state_dict.items():
if "moment1" in opt_name:
suffix_bucket[".moment1"].append(opt_name)
elif "moment2" in opt_name:
suffix_bucket[".moment2"].append(opt_name)
elif "beta1_pow_acc" in opt_name:
suffix_bucket[".beta1_pow_acc"].append(opt_name)
elif "beta2_pow_acc" in opt_name:
suffix_bucket[".beta2_pow_acc"].append(opt_name)
else:
suffix_bucket[".master_weight"].append(satte_name)
suffix_bucket[".master_weight"].append(opt_name)

for suffix, old_names in suffix_bucket.items():
if len(old_names) == 0:
continue
assert len(old_names) == len(model_state_keys)
for i in range(len(old_names)):
name_mapping[old_names[i]] = model_state_keys[i] + suffix
@@ -1011,6 +1019,9 @@ def get_local_checkpoint_file_names(self):
cur_rank_optimizer_state_file_names.append(file_name)
if SCHEDULER_NAME in cur_rank_model_state_file_names:
cur_rank_model_state_file_names.remove(SCHEDULER_NAME)
if SCALAR_NAME in cur_rank_model_state_file_names:
cur_rank_model_state_file_names.remove(SCALAR_NAME)

return cur_rank_model_state_file_names, cur_rank_optimizer_state_file_names

def get_distribution_rank_from_file_name(self, file_name):
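
For context on the ckpt_converter.py changes above, rename_using_optimizer_state_order buckets optimizer states by suffix and then maps the i-th entry of each non-empty bucket to the i-th model parameter's structured name. A standalone sketch of that ordering logic, with made-up names and none of the converter's surrounding machinery, could look like this:

OPTIMIZER_STATE_NAME_SUFFIX = [".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc", ".master_weight"]

def rename_by_order(model_state_keys, optimizer_state_names):
    """Map raw optimizer-state names to '<structured model name><suffix>' by bucket order."""
    suffix_bucket = {suffix: [] for suffix in OPTIMIZER_STATE_NAME_SUFFIX}
    for opt_name in optimizer_state_names:
        for suffix in (".moment1", ".moment2", ".beta1_pow_acc", ".beta2_pow_acc"):
            if suffix[1:] in opt_name:
                suffix_bucket[suffix].append(opt_name)
                break
        else:
            suffix_bucket[".master_weight"].append(opt_name)

    name_mapping = {}
    for suffix, old_names in suffix_bucket.items():
        if not old_names:
            continue
        # Each non-empty bucket is assumed to follow the model parameter order.
        assert len(old_names) == len(model_state_keys)
        for old_name, model_key in zip(old_names, model_state_keys):
            name_mapping[old_name] = model_key + suffix
    return name_mapping

# Example with hypothetical names:
# rename_by_order(["llama.embed_tokens.weight"], ["param_0.moment1_0"])
# returns {"param_0.moment1_0": "llama.embed_tokens.weight.moment1"}
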
160 changes: 160 additions & 0 deletions scripts/distribute/ci_case_auto.sh
@@ -54,6 +54,7 @@ function llama_case_list_auto() {
llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2

llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1
llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
}

function llm_gpt_case_list_auto() {
@@ -1062,6 +1063,165 @@ function llama_align_dygraph_dy2st_auto_bs2_bf16_DP2-MP1-PP1() {
echo "=========== $FUNCNAME run end ==========="
}

function llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=3
export NVIDIA_TF32_OVERRIDE=0
export FLAGS_enable_pir_api=1
export FLAGS_max_inplace_grad_add=3

echo "---- run hybrid and save ckpt ----"
dy_task_name="llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1"
dy_case_out_dir="dy_output/$dy_task_name"
dy_case_log_dir="dy_output/$dy_task_name""_log"
rm -rf $dy_case_out_dir
rm -rf $dy_case_log_dir

python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $dy_case_log_dir \
../../run_pretrain.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $dy_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 5 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 3 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--enable_linear_fused_grad_add false \
--fuse_attention_ffn true \
--fuse_attention_qkv false \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention 0 \
--use_fused_rope false \
--use_fused_rms_norm 0 \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 1 \
--sharding "" \
--to_static 0 \
--num_hidden_layers 2 \
>>${log_path}/$FUNCNAME 2>&1
dy_loss=`cat $dy_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
dy_ips=-1
dy_mem=-1
echo "hybrid result: loss=$dy_loss ips=$dy_ips mem=$dy_mem"

echo "---- run auto parallel resueme from hybrid ckpt ----"
auto_task_name="llama_auto_parallel_bs2_fp32_DP2-MP1-PP1"
auto_case_out_dir="auto_output/$auto_task_name"
auto_case_log_dir="auto_output/$auto_task_name""_log"
rm -rf $auto_case_out_dir
rm -rf $auto_case_log_dir

python -u -m paddle.distributed.launch \
--gpus "0,1" \
--log_dir $auto_case_log_dir \
run_pretrain_auto.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir $auto_case_out_dir \
--split 949,50,1 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--warmup_steps 30 \
--max_grad_norm 0.0 \
--learning_rate 3e-05 \
--min_learning_rate 3e-06 \
--max_steps 4 \
--logging_steps 1 \
--eval_steps 1000 \
--save_steps 1000 \
--continue_training 0 \
--do_train true \
--do_eval false \
--do_predict false \
--disable_tqdm true \
--skip_profile_timer true \
--save_total_limit 2 \
--device gpu \
--disable_tqdm true \
--dataloader_num_workers 1 \
--distributed_dataloader 0 \
--enable_auto_parallel 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--per_device_eval_batch_size 2 \
--recompute false \
--recompute_use_reentrant true \
--recompute_granularity full \
--pp_recompute_interval 0 \
--bf16 0 \
--fp16_opt_level "O2" \
--amp_custom_black_list "reduce_sum" "c_softmax_with_cross_entropy" \
--amp_custom_white_list "lookup_table" "lookup_table_v2" \
--amp_master_grad false \
--fuse_attention_ffn true \
--fuse_attention_qkv false \
--fuse_sequence_parallel_allreduce false \
--use_flash_attention 0 \
--use_fused_rope false \
--use_fused_rms_norm 0 \
--max_seq_length 4096 \
--sep_parallel_degree 1 \
--sequence_parallel false \
--pipeline_parallel_degree 1 \
--sharding_parallel_degree 1 \
--tensor_parallel_degree 1 \
--virtual_pp_degree 1 \
--pipeline_schedule_mode "VPP" \
--sharding "" \
--to_static 1 \
--num_hidden_layers 2 \
--resume_from_checkpoint "dy_output/llama_hybrid_ckpt_bs2_fp32_DP2-MP1-PP1/checkpoint-3" \
--auto_parallel_resume_form_hybrid_parallel 1 \
>>${log_path}/$FUNCNAME 2>&1
auto_loss=`cat $auto_case_log_dir/workerlog.0 | grep 'global_step: 4' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
auto_ips=-1
auto_mem=-1
echo "auto result: loss=$auto_loss ips=$auto_ips mem=$auto_mem"

check_result $FUNCNAME ${dy_loss} ${auto_loss} ${dy_ips} ${auto_ips} ${dy_mem} ${auto_mem}
echo "=========== $FUNCNAME run end ==========="
}
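
The new case above first trains a 2-layer LLaMA for 5 steps under DP2 hybrid parallelism and saves a checkpoint at step 3, then resumes a static auto-parallel run from that checkpoint via --auto_parallel_resume_form_hybrid_parallel and compares the step-4 losses with check_result. A hypothetical standalone invocation, assuming the surrounding CI harness provides root_path, log_path, check_result, the pretraining data under ./data, and the working directory shown, might look like:

# Hypothetical manual run of the new case; the exports and working directory below are
# assumptions, not part of the committed script.
export root_path=/workspace/PaddleNLP
export log_path=/tmp/paddlenlp_ci_logs
mkdir -p "$log_path"
cd "$root_path/llm/auto_parallel/llama"   # assumed cwd containing run_pretrain_auto.py and ./data
source "$root_path/scripts/distribute/ci_case_auto.sh"
llama_convert_hybrid_ckpt_to_auto_parallel_bs2_fp32_DP2-MP1-PP1
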

function llm_gpt_dygraph_auto_bs8_fp32_DP2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
