[AutoParallel] Support 'master_grad' for Llama in static auto-parallelism #7658

Merged: 7 commits, Jan 24, 2024
4 changes: 1 addition & 3 deletions paddlenlp/trainer/training_args.py
@@ -1174,9 +1174,6 @@
pipeline.micro_batch_size = self.per_device_train_batch_size
pipeline.schedule_mode = self.pipeline_schedule_mode

if self.amp_master_grad:
    warnings.warn("`amp_master_grad` is not supported NOW in AutoParallel!")
    self.amp_master_grad = False
logger.info(f"PP configs:{strategy.pipeline}, use master_grad: {self.amp_master_grad}")

if self.do_eval:
@@ -1260,6 +1257,7 @@
amp.enable = True
amp.dtype = "bfloat16" if self.bf16 else "float16"
amp.level = self.fp16_opt_level.lower()
amp.use_master_grad = self.amp_master_grad

amp.init_loss_scaling = self.scale_loss
amp.custom_black_list = self.amp_custom_black_list if self.amp_custom_black_list is not None else []
amp.custom_white_list = self.amp_custom_white_list if self.amp_custom_white_list is not None else []
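For context, the removed block above used to force `amp_master_grad` off with a warning, while the new `amp.use_master_grad = self.amp_master_grad` assignment forwards the trainer's `--amp_master_grad` flag into the static auto-parallel AMP strategy, so fp16/bf16 O2 training can keep an fp32 master copy of the gradients for the optimizer update. A minimal NumPy sketch (illustrative only, not Paddle code) of the rounding loss that fp32 master gradients avoid:

```python
import numpy as np

# Illustrative only: once the accumulated gradient is large relative to each
# micro-step contribution, pure-fp16 accumulation silently drops the updates,
# while an fp32 master copy keeps them.
fp16_grad = np.float16(1.0)      # gradient accumulated so far, held in fp16
fp32_master = np.float32(1.0)    # the same value held as an fp32 master copy
delta = np.float16(4e-4)         # one micro-batch's gradient contribution

for _ in range(8):               # e.g. gradient_accumulation_steps 8, as in the CI case below
    fp16_grad = fp16_grad + delta                  # 1.0 + 4e-4 rounds back to 1.0 in fp16
    fp32_master = fp32_master + np.float32(delta)  # accumulates correctly in fp32

print(fp16_grad)    # 1.0     -> all eight contributions were lost
print(fp32_master)  # ~1.0032 -> retained by the master copy
```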
71 changes: 70 additions & 1 deletion scripts/distribute/ci_case_auto.sh
@@ -52,6 +52,7 @@ function llama_case_list_auto() {
llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP1
llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2
llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2
}

function gpt_case_list_auto_pir() {
@@ -1168,6 +1169,75 @@ function llama_static_auto_recompute_bs16_fp32_DP2-MP2-PP2-VPP2-Sharding2_stage2
echo "=========== $FUNCNAME run end ==========="
}

function llama_static_auto_recompute_bs16_fp16_DP2-MP2-PP2-VPP2-Sharding2_stage2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=2

task_name="llama_auto_bs16_fp16_dp2mp2pp2vpp2sharding2"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
rm -rf $case_log_dir

python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" --log_dir $case_log_dir run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--hidden_size 1024 \
--intermediate_size 3072 \
--num_hidden_layers 8 \
--num_attention_heads 32 \
--input_dir "./data" \
--output_dir $case_out_dir \
--split 949,50,1 \
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 1 \
--fp16_opt_level "O2" \
--amp_master_grad 1 \
--scale_loss 1024 \
--tensor_parallel_degree 2 \
--pipeline_parallel_degree 2 \
--virtual_pp_degree 2 \
--pipeline_schedule_mode "VPP" \
--sharding_parallel_degree 2 \
--sharding "stage2" \
--learning_rate 0.0001 \
--min_learning_rate 0.00001 \
--max_steps 10 \
--save_steps 5000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--dataloader_num_workers 1 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
--device "gpu" \
--data_impl "mmap" \
--parallel_mode "auto" \
>>${log_path}/$FUNCNAME 2>&1
loss=`cat $case_log_dir/workerlog.3 | grep 'global_step: 10' | awk -F 'loss: ' '{print $2}' | awk -F ',' '{print $1}'`
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=10.0859375
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}
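The pass/fail signal for this case is the tail of the function: `workerlog.3` is grepped for the `global_step: 10` line, the number after `loss: ` is cut out with awk, and `check_result` compares it against `loss_base=10.0859375` (ips and mem are not checked here, hence the `-1` placeholders). A small Python restatement of that parsing step (illustrative; the CI itself uses the grep/awk pipeline above, and the log line shown is a hypothetical example):

```python
import re

# Hypothetical final-step log line; the real format comes from run_pretrain_auto.py.
log_line = "global_step: 10, loss: 10.0859375, lr: 9.25e-05"
loss_base = 10.0859375

match = re.search(r"loss: ([0-9.]+)", log_line)   # same extraction as the awk pipeline
assert match is not None, "no loss on the global_step: 10 line"
loss = float(match.group(1))

print(f"result: loss={loss} ips=-1 mem=-1")
# check_result decides pass/fail against the baseline; a simple tolerance check is sketched here.
assert abs(loss - loss_base) < 1e-6, f"loss {loss} deviates from baseline {loss_base}"
```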

function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
@@ -1233,7 +1303,6 @@ function llama_dygraph_auto_bs4_fp32_DP2-MP2-PP2() {
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

############ case end ############

function check_result() {
2 changes: 2 additions & 0 deletions scripts/distribute/run_ci.sh
@@ -22,11 +22,13 @@ export case_list=()

target_lists_for_gpt=(
"model_zoo/gpt-3"
"scripts/distribute"
)

target_lists_for_llama=(
"llm/llama/auto_parallel"
"paddlenlp/transformers/llama/modeling_auto.py"
"scripts/distribute"
)

target_path_for_ci_scripts="scripts/distribute"
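Adding `scripts/distribute` to both target lists means changes to the CI scripts themselves (such as this PR's new case) also trigger the gpt and llama auto-parallel case lists. A hedged sketch of how such prefix lists are typically matched against a PR's changed files (hypothetical logic, not the actual run_ci.sh implementation, which is written in shell):

```python
# Hypothetical prefix matching for CI gating; run_ci.sh's real rules may differ.
target_lists_for_llama = [
    "llm/llama/auto_parallel",
    "paddlenlp/transformers/llama/modeling_auto.py",
    "scripts/distribute",
]

def should_run_llama_cases(changed_files):
    """True if any changed file falls under one of the llama target paths."""
    return any(f.startswith(prefix) for f in changed_files for prefix in target_lists_for_llama)

# Example: this PR touches scripts/distribute/ci_case_auto.sh, so the llama cases run.
print(should_run_llama_cases(["scripts/distribute/ci_case_auto.sh"]))  # True
```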