
fix grads acc in llama auto (#7625)
haohongxiang authored Dec 12, 2023
1 parent 0a58a1a commit a4ed7ac
Showing 4 changed files with 56 additions and 62 deletions.
4 changes: 2 additions & 2 deletions llm/llama/auto_parallel/run_auto.sh
@@ -56,13 +56,13 @@ python -u -m paddle.distributed.launch \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0\
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
48 changes: 19 additions & 29 deletions llm/llama/auto_parallel/run_pretrain_auto.py
@@ -410,23 +410,6 @@ def init_seed(seed: int = 1234, args=None):
paddle.seed(args.seed)


def validate_batch(batch, args):
batches = []
if args.pipeline_parallel_degree > 1 or args.gradient_accumulation_steps == 1:
batches = batch
else:
feed_names = []
split_batches = []
for n, b in batch[0].items():
feed_names.append(n)
split_batches.append(np.split(np.array(b), args.gradient_accumulation_steps, 0))
for i in range(len(split_batches[0])):
micro_batch = [split_batch[i] for split_batch in split_batches]
batches.append(dict(zip(feed_names, micro_batch)))

return batches


def main():
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -562,11 +545,11 @@ def fn(layer):
def loss_func(loss, outputs):
return loss

total_train_batch_size = (
training_args.per_device_train_batch_size
* training_args.gradient_accumulation_steps
* training_args.data_parallel_degree
total_train_batch_size_per_acc_step = (
training_args.per_device_train_batch_size * training_args.data_parallel_degree
)
total_train_batch_size = total_train_batch_size_per_acc_step * training_args.gradient_accumulation_steps

print_config(training_args)

engine = auto.Engine(model, loss_func, optimizer, strategy=training_args.strategy)
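As a quick sanity check of the two quantities introduced in this hunk, here is a minimal sketch in plain Python, using illustrative values from the DP2 CI cases added later in this commit (per-device batch size 1, DP degree 2, 8 accumulation steps); it is not engine code, just the arithmetic:

```python
# Illustrative values taken from the DP2 CI cases in this commit.
per_device_train_batch_size = 1
data_parallel_degree = 2
gradient_accumulation_steps = 8

# Samples consumed per accumulation step, summed over the data-parallel group.
total_train_batch_size_per_acc_step = per_device_train_batch_size * data_parallel_degree

# Samples consumed per optimizer step, i.e. the effective global batch size.
total_train_batch_size = total_train_batch_size_per_acc_step * gradient_accumulation_steps

print(total_train_batch_size_per_acc_step, total_train_batch_size)  # 2 16
```

The dataloader below is then built with the per-accumulation-step size when pp_degree == 1 (the training loop does its own accumulation), and with the full per-optimizer-step size when pipeline parallelism is on.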
@@ -582,14 +565,14 @@ def loss_func(loss, outputs):
mode="train",
)

dp_degree = training_args.data_parallel_degree
mp_degree = training_args.tensor_parallel_degree
pp_degree = training_args.pipeline_parallel_degree
dp_degree = max(training_args.data_parallel_degree, 1)
mp_degree = max(training_args.tensor_parallel_degree, 1)
pp_degree = max(training_args.pipeline_parallel_degree, 1)
assert dp_degree * mp_degree * pp_degree == dist.get_world_size()

train_dataloader = engine.dataloader(
dataset=train_dataset,
batch_size=total_train_batch_size,
batch_size=total_train_batch_size_per_acc_step if pp_degree == 1 else total_train_batch_size,
steps_per_epoch=training_args.max_steps,
epochs=training_args.num_train_epochs,
collate_fn=data_collator,
@@ -607,11 +590,16 @@ def loss_func(loss, outputs):
global_step_last_logged = 0
start_time_last_logged = time.time()
tr_loss = float(0)
local_batches = []
for epoch_idx in range(num_train_epochs):
for step, inputs in enumerate(train_dataloader):
batches = validate_batch(inputs, training_args)
local_batches.append(inputs)
if pp_degree == 1 and len(local_batches) < training_args.gradient_accumulation_steps:
continue
elif pp_degree > 1:
local_batches = inputs

for micro_batch in batches:
for micro_batch in local_batches:
outs = engine.run(micro_batch, mode="train")

if "loss" in outs:
@@ -624,11 +612,13 @@

tr_loss += tr_loss_step

local_batches = []

if lr_scheduler is not None:
engine.optimizer._learning_rate.step()

global_step += 1
if (step + 1) % training_args.logging_steps == 0:
if global_step % training_args.logging_steps == 0:
num_steps = global_step - global_step_last_logged
logs = {}
logs["loss"] = round(tr_loss / num_steps, 8)
@@ -648,7 +638,7 @@ def loss_func(loss, outputs):
start_time_last_logged = time.time()
tr_loss = float(0)

if step >= training_args.max_steps:
if global_step >= training_args.max_steps:
break


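Taken together, the run_pretrain_auto.py changes replace the old validate_batch() splitting with explicit micro-batch buffering: when pp_degree == 1 the loop collects one dataloader batch per accumulation step and only runs the optimizer (and increments global_step) after gradient_accumulation_steps of them, which is also why logging and the max_steps stop condition now key off global_step instead of the dataloader index. A stand-alone sketch of that control flow is below; run_micro_batch() and optimizer_step() are hypothetical stand-ins for the engine calls, not PaddleNLP APIs.

```python
# Toy sketch of the accumulation control flow used in the new training loop.
# run_micro_batch() and optimizer_step() are hypothetical stand-ins, not real APIs.
import random

def run_micro_batch(batch):
    """Pretend forward/backward on one micro-batch; returns a loss value."""
    return random.random()

def optimizer_step():
    """Pretend parameter update + learning-rate step."""

def train(dataloader, gradient_accumulation_steps, logging_steps, max_steps, pp_degree=1):
    global_step, tr_loss, local_batches = 0, 0.0, []
    for step, inputs in enumerate(dataloader):
        local_batches.append(inputs)
        if pp_degree == 1 and len(local_batches) < gradient_accumulation_steps:
            continue  # keep buffering; no optimizer step yet
        elif pp_degree > 1:
            # With pipeline parallelism the dataloader already yields the full
            # per-optimizer-step batch and the engine splits it internally;
            # here we just wrap it so the loop below still works.
            local_batches = [inputs]

        for micro_batch in local_batches:
            tr_loss += run_micro_batch(micro_batch)

        local_batches = []
        optimizer_step()   # exactly one update per accumulated batch
        global_step += 1   # counts optimizer steps, not dataloader iterations
        if global_step % logging_steps == 0:
            print(f"global_step {global_step}: avg loss {tr_loss / logging_steps:.4f}")
            tr_loss = 0.0
        if global_step >= max_steps:
            return

train(range(64), gradient_accumulation_steps=8, logging_steps=1, max_steps=4)
```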
8 changes: 6 additions & 2 deletions paddlenlp/trainer/training_args.py
@@ -1301,15 +1301,19 @@ def data_parallel_rank(self):

@property
def dataset_rank(self):
if self.use_hybrid_parallel or self.use_auto_parallel:
if self.use_hybrid_parallel:
return max(self.sharding_parallel_degree, 1) * self.data_parallel_rank + self.sharding_parallel_rank
elif self.use_auto_parallel:
return self.data_parallel_rank
else:
return paddle.distributed.get_rank()

@property
def dataset_world_size(self):
if self.use_hybrid_parallel or self.use_auto_parallel:
if self.use_hybrid_parallel:
return max(self.sharding_parallel_degree, 1) * max(self.data_parallel_degree, 1)
elif self.use_auto_parallel:
return max(self.data_parallel_degree, 1)
else:
return paddle.distributed.get_world_size()

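The training_args.py change splits the previously shared branch: under hybrid parallelism the dataset shard is determined by the sharding and data-parallel groups together, while under auto parallelism (no sharding group) it follows the data-parallel degree alone. A small illustration with hypothetical degree/rank values:

```python
# Hypothetical values for one rank of a DP2-MP2-PP2 auto-parallel job (8 ranks total).
data_parallel_degree = 2
data_parallel_rank = 1          # this process's position in its data-parallel group
sharding_parallel_degree = 0    # no sharding group under auto parallel
sharding_parallel_rank = 0

# use_auto_parallel branch (new):
dataset_rank = data_parallel_rank                    # -> 1
dataset_world_size = max(data_parallel_degree, 1)    # -> 2

# use_hybrid_parallel branch (unchanged), shown for contrast:
hybrid_dataset_rank = max(sharding_parallel_degree, 1) * data_parallel_rank + sharding_parallel_rank
hybrid_dataset_world_size = max(sharding_parallel_degree, 1) * max(data_parallel_degree, 1)

print(dataset_rank, dataset_world_size)  # 1 2
```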
58 changes: 29 additions & 29 deletions scripts/distribute/ci_case_auto.sh
@@ -45,10 +45,10 @@ function gpt_case_list_auto() {
}

function llama_case_list_auto() {
llama_auto_recompute_bs1_fp32_DP1-MP1-PP1
llama_auto_recompute_bs2_fp32_DP2-MP1-PP1
llama_auto_recompute_bs2_fp32_DP2-MP2-PP1
llama_auto_recompute_bs8_fp32_DP2-MP2-PP2
llama_auto_recompute_bs8_fp32_DP1-MP1-PP1
llama_auto_recompute_bs16_fp32_DP2-MP1-PP1
llama_auto_recompute_bs16_fp32_DP2-MP2-PP1
llama_auto_recompute_bs16_fp32_DP2-MP2-PP2
}

function case_list_auto_pir() {
@@ -833,13 +833,13 @@ function gpt_auto_sp_acc_check() {
echo "=========== $FUNCNAME run end ==========="
}

function llama_auto_recompute_bs1_fp32_DP1-MP1-PP1() {
function llama_auto_recompute_bs8_fp32_DP1-MP1-PP1() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=2
export SOT_LOG_LEVEL=4

task_name="llama_auto_bs1_dp1mp1pp1"
task_name="llama_auto_bs8_dp1mp1pp1"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
@@ -859,7 +859,7 @@ function llama_auto_recompute_bs1_fp32_DP1-MP1-PP1() {
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
@@ -876,13 +876,13 @@ function llama_auto_recompute_bs1_fp32_DP1-MP1-PP1() {
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0\
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
@@ -894,20 +894,20 @@ function llama_auto_recompute_bs1_fp32_DP1-MP1-PP1() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.71193314
loss_base=9.52110565
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

function llama_auto_recompute_bs2_fp32_DP2-MP1-PP1() {
function llama_auto_recompute_bs16_fp32_DP2-MP1-PP1() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=2
export SOT_LOG_LEVEL=4

task_name="llama_auto_bs2_dp2mp1pp1"
task_name="llama_auto_bs16_dp2mp1pp1"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
@@ -927,7 +927,7 @@ function llama_auto_recompute_bs2_fp32_DP2-MP1-PP1() {
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
@@ -944,13 +944,13 @@ function llama_auto_recompute_bs2_fp32_DP2-MP1-PP1() {
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0\
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
@@ -962,20 +962,20 @@ function llama_auto_recompute_bs2_fp32_DP2-MP1-PP1() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.57837963
loss_base=9.41858447
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

function llama_auto_recompute_bs2_fp32_DP2-MP2-PP1() {
function llama_auto_recompute_bs16_fp32_DP2-MP2-PP1() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=2
export SOT_LOG_LEVEL=4

task_name="llama_auto_bs2_dp2mp2pp1"
task_name="llama_auto_bs16_dp2mp2pp1"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
@@ -995,7 +995,7 @@ function llama_auto_recompute_bs2_fp32_DP2-MP2-PP1() {
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
@@ -1012,13 +1012,13 @@ function llama_auto_recompute_bs2_fp32_DP2-MP2-PP1() {
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0\
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
@@ -1030,20 +1030,20 @@ function llama_auto_recompute_bs2_fp32_DP2-MP2-PP1() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.6846962
loss_base=9.53447247
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
echo "=========== $FUNCNAME run end ==========="
}

function llama_auto_recompute_bs8_fp32_DP2-MP2-PP2() {
function llama_auto_recompute_bs16_fp32_DP2-MP2-PP2() {
echo "=========== $FUNCNAME run begin ==========="
export PYTHONPATH=$root_path/:$PYTHONPATH
export FLAGS_call_stack_level=2
export SOT_LOG_LEVEL=4

task_name="llama_auto_bs2_dp2mp2pp1"
task_name="llama_auto_bs16_dp2mp2pp1"
case_out_dir="output/$task_name"
case_log_dir="output/$task_name""_log"
rm -rf $case_out_dir
@@ -1062,8 +1062,8 @@ function llama_auto_recompute_bs8_fp32_DP2-MP2-PP2() {
--split 949,50,1 \
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
@@ -1080,13 +1080,13 @@ function llama_auto_recompute_bs8_fp32_DP2-MP2-PP2() {
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1\
--logging_steps 1 \
--dataloader_num_workers 1 \
--sharding "" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0\
--continue_training 0 \
--recompute 1 \
--do_train \
--do_eval \
@@ -1098,7 +1098,7 @@ function llama_auto_recompute_bs8_fp32_DP2-MP2-PP2() {
ips=-1
mem=-1
echo "result: loss=$loss ips=$ips mem=$mem"
loss_base=9.59060478
loss_base=9.52331257
ips_base=-1
mem_base=-1
check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}

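Finally, the CI case renames in ci_case_auto.sh encode the new effective global batch size: every case now uses gradient_accumulation_steps=8, so with a per-device batch size of 1 the bsN in each name is per_device × dp_degree × acc_steps. A quick check (values read from the launch commands above):

```python
# Effective global batch size for each renamed CI case:
# per_device_train_batch_size * data_parallel_degree * gradient_accumulation_steps
cases = {
    "llama_auto_recompute_bs8_fp32_DP1-MP1-PP1":  dict(per_device=1, dp=1, acc=8),
    "llama_auto_recompute_bs16_fp32_DP2-MP1-PP1": dict(per_device=1, dp=2, acc=8),
    "llama_auto_recompute_bs16_fp32_DP2-MP2-PP1": dict(per_device=1, dp=2, acc=8),
    "llama_auto_recompute_bs16_fp32_DP2-MP2-PP2": dict(per_device=1, dp=2, acc=8),
}
for name, c in cases.items():
    print(name, "->", c["per_device"] * c["dp"] * c["acc"])  # 8, 16, 16, 16
```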