Support XPU for auto-parallel LLaMa
From00 committed Jan 20, 2025
1 parent 13053a7 commit 859900e
Showing 5 changed files with 203 additions and 4 deletions.
10 changes: 10 additions & 0 deletions llm/auto_parallel/llama/run_pretrain_auto.py
@@ -58,6 +58,7 @@
check_data_split,
print_rank_0,
)
from paddlenlp.utils.tools import get_env_device
from paddlenlp.trainer.utils.doc import add_start_docstrings


@@ -544,6 +545,15 @@ def main():
pipeline = training_args.strategy.pipeline
pipeline.vpp_degree = config.virtual_pp_degree
pipeline.vpp_seg_method = training_args.virtual_pipeline_seg_method
if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
except ImportError:
# paddle_xpu is optional; continue without the accumulate_steps optimization
pass

print("Final pre-training config:", config)

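Note: the XPU hook added above only takes effect when the optional paddle_xpu package is installed; otherwise the ImportError branch silently skips the optimization. A minimal standalone sketch of that guarded-import pattern, reusing only the import path and LinearConfig calls visible in this diff (everything else is illustrative):

def try_enable_xpu_accumulate_opt(gradient_accumulation_steps: int) -> bool:
    # Enable the optional paddle_xpu Linear accumulate-steps optimization if available.
    try:
        from paddle_xpu.layers.nn.linear import LinearConfig
    except ImportError:
        # paddle_xpu not installed: training still works, just without this optimization.
        return False
    LinearConfig.enable_accumulate_steps_opt()
    LinearConfig.set_accumulate_steps(gradient_accumulation_steps)
    return True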
3 changes: 1 addition & 2 deletions paddlenlp/trainer/training_args.py
@@ -1632,13 +1632,12 @@ def is_segment_parallel_supported():
"enable_mp_async_allreduce", # allreduce_matmul_grad_overlapping in auto_parallel
"enable_delay_scale_loss",
"replace_with_c_embedding",
# "enable_mp_skip_c_identity",
# "enable_mp_fused_linear_param_grad_add",
"replace_with_parallel_cross_entropy",
]:
raise ValueError(
f"Found unknown tensor parallell config {x}, "
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, enable_mp_skip_c_identity and enable_mp_fused_linear_param_grad_add"
f"accept config is enable_mp_async_allreduce, replace_with_c_embedding, and enable_mp_fused_linear_param_grad_add"
)
try:
if "enable_mp_async_allreduce" in mp_config:
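Note: for reference, a self-contained sketch of the tightened validation above, using only the option names that remain in the accepted list in this hunk (the real check lives in training_args.py and handles more cases):

ACCEPTED_MP_CONFIGS = {
    "enable_mp_async_allreduce",
    "enable_delay_scale_loss",
    "replace_with_c_embedding",
    "replace_with_parallel_cross_entropy",
}

def check_tensor_parallel_config(value: str) -> None:
    # value mirrors the comma-separated --tensor_parallel_config CLI flag
    for x in value.split(","):
        if x and x not in ACCEPTED_MP_CONFIGS:
            raise ValueError(f"Found unknown tensor parallel config {x}")

check_tensor_parallel_config("enable_mp_async_allreduce,replace_with_c_embedding")  # passes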
20 changes: 18 additions & 2 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -52,6 +52,7 @@ def swiglu(x, y=None):
CausalLMOutputWithCrossAttentions,
)
from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model
from paddlenlp.utils.tools import get_env_device

from .configuration import (
LLAMA_PRETRAINED_INIT_CONFIGURATION,
@@ -308,7 +309,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp:
self.ipp = ipp

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope:
if self.use_fused_rope and get_env_device() not in ["npu", "mlu", "xpu", "gcu", "intel_hpu"]:

if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
@@ -935,7 +936,22 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
else:
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")  # XPU-specific fill value, smaller magnitude than finfo(float32).min
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
elif get_env_device() == "gcu":
min_val = paddle.finfo(dtype).min
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(min_val, dtype=dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)

else:
expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
expanded_attn_mask = expanded_attn_mask.astype(dtype)

return expanded_attn_mask

def forward(
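Note: the device-specific branches added to _prepare_decoder_attention_mask above all do the same thing: turn a boolean attention mask into an additive float mask, where positions to attend contribute 0.0 to the attention scores and masked positions contribute a large negative value (finfo(dtype).min in general, a smaller-magnitude constant on XPU). A minimal sketch of that conversion, using only Paddle calls that appear in the diff; the hard-coded XPU constant is copied from it:

import paddle

def to_additive_mask(bool_mask, dtype=paddle.float32, is_xpu=False):
    # True -> attend (adds 0.0), False -> masked (adds a large negative value)
    zero = paddle.to_tensor(0.0, dtype="float32")
    neg_value = -1.7005809656952787e38 if is_xpu else float(paddle.finfo(dtype).min)
    neg = paddle.to_tensor(neg_value, dtype="float32")
    return paddle.where(bool_mask.cast("bool"), zero, neg).astype(dtype)

causal = paddle.tril(paddle.ones([1, 1, 4, 4])).cast("bool")  # lower-triangular causal mask
print(to_additive_mask(causal))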
87 changes: 87 additions & 0 deletions run_llama2_13b_4k_auto.sh
@@ -0,0 +1,87 @@
#!/bin/bash
cd llm
task_name_or_path="llama2-13b-4k"

#export XPUAPI_DEBUG=0x1
#export XPURT_DISPATCH_MODE=PROFILING
export XBLAS_FC_HBM_VERSION=40

# PaddlePaddle
export FLAGS_use_stride_kernel="0"
export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2

# PDC
unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
unset PADDLE_TRAINERS_NUM

# BKCL
# export BKCL_DEBUG=1
# Multi-node RDMA
#export BKCL_ENABLE_XDR=1
#export BKCL_RDMA_FORCE_TREE=1
#export BKCL_TREE_THRESHOLD=0
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

export CUDA_DEVICE_MAX_CONNECTIONS=8

export GLOG_v=10

timestamp=$(date +%Y%m%d%H%M%S)
echo $timestamp
PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
--xpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name_or_path/$timestamp""_log" \
auto_parallel/llama/run_pretrain_auto.py \
--model_name_or_path "meta-llama/Llama-2-13b" \
--tokenizer_name_or_path "meta-llama/Llama-2-13b" \
--input_dir "./data" \
--output_dir "output/$task_name_or_path/$timestamp" \
--split 949,50,1 \
--max_seq_length 4096 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--use_flash_attention 1 \
--use_fused_rope 1 \
--fuse_attention_ffn 1 \
--fuse_attention_qkv 1 \
--use_fused_rms_norm 1 \
--num_hidden_layers 40 \
--bf16 \
--fp16_opt_level "O2" \
--amp_master_grad true \
--scale_loss 1024 \
--learning_rate 0.00003 \
--min_learning_rate 0.000005 \
--lr_scheduler_type "cosine" \
--max_steps 100000 \
--save_steps 100000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--sequence_parallel 0 \
--dataloader_num_workers 4 \
--pipeline_parallel_degree 2 \
--tensor_parallel_degree 2 \
--gradient_accumulation_steps 32 \
--sharding "stage1" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 0 \
--do_train \
--seed 1026 \
--device "xpu" \
--enable_auto_parallel 1
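
Note: before launching the script above, it can help to confirm that the installed Paddle wheel actually has XPU support. A small pre-flight check (sketch, run from the same environment):

import paddle

if not paddle.is_compiled_with_xpu():
    raise SystemExit("This Paddle build has no XPU support; install the XPU wheel first.")
print("current device:", paddle.device.get_device())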
87 changes: 87 additions & 0 deletions run_llama2_13b_4k_pp2.sh
@@ -0,0 +1,87 @@
#!/bin/bash
cd llm
task_name_or_path="llama2-13b-4k"

#export XPUAPI_DEBUG=0x1
#export XPURT_DISPATCH_MODE=PROFILING
export XBLAS_FC_HBM_VERSION=40

# PaddlePaddle
export FLAGS_use_stride_kernel="0"
export XPU_PADDLE_L3_SIZE=98566144 # 94 MB
export XPU_CDNN_CLUSTER_PARALLEL=1
export XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER=2

# PDC
unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT
unset PADDLE_TRAINERS_NUM

# BKCL
# export BKCL_DEBUG=1
# Multi-node RDMA
#export BKCL_ENABLE_XDR=1
#export BKCL_RDMA_FORCE_TREE=1
#export BKCL_TREE_THRESHOLD=0
#export BKCL_RDMA_NICS=xgbe1,xgbe1,xgbe2,xgbe2,xgbe3,xgbe3,xgbe4,xgbe4
#export BKCL_SOCKET_IFNAME=xgbe0
#export BKCL_FORCE_L3_RDMA=0
echo "bkcl version:"
strings ${bkcl_location}/libbkcl.so | grep COM

export CUDA_DEVICE_MAX_CONNECTIONS=8

timestamp=$(date +%Y%m%d%H%M%S)
echo $timestamp
PYTHONPATH=../:$PYTHONPATH \
python -u -m paddle.distributed.launch \
--xpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name_or_path/$timestamp""_log" \
run_pretrain.py \
--model_name_or_path "meta-llama/Llama-2-13b" \
--tokenizer_name_or_path "meta-llama/Llama-2-13b" \
--input_dir "./data" \
--output_dir "output/$task_name_or_path/$timestamp" \
--split 949,50,1 \
--max_seq_length 4096 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--use_flash_attention 1 \
--use_fused_rope 1 \
--fuse_attention_ffn 1 \
--fuse_attention_qkv 1 \
--use_fused_rms_norm 1 \
--num_hidden_layers 40 \
--bf16 \
--fp16_opt_level "O2" \
--scale_loss 1024 \
--learning_rate 0.00003 \
--min_learning_rate 0.000005 \
--lr_scheduler_type "cosine" \
--max_steps 100000 \
--save_steps 100000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--sequence_parallel 0 \
--dataloader_num_workers 4 \
--pipeline_parallel_degree 2 \
--pipeline_parallel_config "disable_partial_send_recv" \
--tensor_parallel_degree 2 \
--tensor_parallel_config "enable_mp_async_allreduce,enable_mp_skip_c_identity" \
--gradient_accumulation_steps 32 \
--sharding "stage1" \
--sharding_parallel_config "split_param" \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 0 \
--do_train \
--seed 1026 \
--device "xpu" \
--amp_master_grad true
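
Note: for orientation, the parallel layout implied by the flags above: 8 XPUs split into tensor-parallel 2 x pipeline-parallel 2 leaves a data-parallel degree of 2, so with per-device batch 1 and 32 accumulation steps each optimizer step sees an effective global batch of 64 sequences of 4096 tokens (assuming a single 8-card node and the usual dp = num_devices / (tp * pp) decomposition):

num_devices = 8            # --xpus "0,1,2,3,4,5,6,7"
tp, pp = 2, 2              # --tensor_parallel_degree / --pipeline_parallel_degree
per_device_batch = 1       # --per_device_train_batch_size
grad_accum = 32            # --gradient_accumulation_steps

dp = num_devices // (tp * pp)                        # data-parallel degree = 2
global_batch = per_device_batch * grad_accum * dp    # 64 sequences per optimizer step
print(dp, global_batch)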
