Skip to content

Commit 72094d4

Browse files
authored
enable dcu ci (#3402)
1 parent 73d60fe commit 72094d4

File tree

11 files changed

+295
-5
lines changed

11 files changed

+295
-5
lines changed

custom_ops/gpu_ops/get_padding_offset.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ __global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
4646
const int ti = threadIdx.x;
4747
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
4848
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
49+
#ifdef PADDLE_WITH_HIP
50+
batch_id_per_token[bi * max_seq_len - cum_offset + i] = cum_offset;
51+
#else
4952
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
53+
#endif
5054
}
5155
if (ti == 0) {
5256
cum_offsets_out[bi] = cum_offset;

fastdeploy/model_executor/forward_meta.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,13 @@ class XPUForwardMeta(ForwardMeta):
197197
dec_batch: Optional[paddle.Tensor] = None
198198
#
199199
total_enc_len: Optional[paddle.Tensor] = None
200+
201+
202+
@dataclass
203+
class DCUForwardMeta(ForwardMeta):
204+
"""
205+
DCUForwardMeta is used to store the global meta information of the forward, and some DCU specific meta info.
206+
"""
207+
208+
# Accumulated offset
209+
cum_offsets: Optional[paddle.Tensor] = None

fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def forward_mixed(
154154
forward_meta.seq_lens_encoder,
155155
forward_meta.seq_lens_decoder,
156156
forward_meta.seq_lens_this_time,
157-
forward_meta.padding_offset,
157+
forward_meta.batch_id_per_token,
158158
forward_meta.cum_offsets,
159159
forward_meta.cu_seqlens_q,
160160
forward_meta.cu_seqlens_k,

fastdeploy/model_executor/layers/backends/dcu/fused_moe_triton_backends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,19 @@ def apply(
101101
self,
102102
layer: nn.Layer,
103103
x: paddle.Tensor,
104-
gate_out: paddle.Tensor,
104+
gate: nn.Layer,
105105
) -> paddle.Tensor:
106106
"""
107107
Triton compute Fused MoE.
108108
"""
109+
gate_out = gate(x.cast("float32"))
109110
token_num = x.shape[0]
110111
top_k = layer.top_k
111112
num_local_experts = layer.num_local_experts
112113
top_k = layer.top_k
113114
moe_intermediate_size = layer.moe_intermediate_size
114115
hidden_size = layer.hidden_size
115116

116-
gate_out = paddle.matmul(x.cast("float32"), layer.gate_weight)
117117
scores = paddle.nn.functional.softmax(gate_out, axis=-1)
118118
scores += layer.gate_correction_bias
119119
topk_weights, topk_ids = paddle.topk(scores, k=top_k, axis=-1, sorted=False)

fastdeploy/model_executor/layers/backends/dcu/top_p_sampling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def native_top_p_sampling(probs: paddle.Tensor, top_p: paddle.Tensor) -> tuple[p
2121
sorted_indices = paddle.argsort(probs, descending=True)
2222
sorted_probs = paddle.sort(probs, descending=True)
2323
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)
24+
if probs.shape[0] != top_p.shape[0]:
25+
top_p = paddle.slice(top_p, [0], [0], [probs.shape[0]])
2426
sorted_indices_to_remove = cumulative_probs > top_p
2527
sorted_indices_to_remove = paddle.cast(sorted_indices_to_remove, dtype="int64")
2628
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def post_process_normal(
218218
model_output.stop_flags,
219219
)
220220

221-
if current_platform.is_cuda() or current_platform.is_iluvatar():
221+
if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu():
222222
set_stop_value_multi_ends(
223223
sampler_output.sampled_token_ids,
224224
model_output.stop_flags,
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import paddle
18+
19+
from fastdeploy.config import FDConfig
20+
from fastdeploy.model_executor.forward_meta import DCUForwardMeta
21+
from fastdeploy.worker.gpu_model_runner import GPUModelRunner
22+
23+
24+
class DCUModelRunner(GPUModelRunner):
25+
def __init__(
26+
self,
27+
fd_config: FDConfig,
28+
device: str, # logic device
29+
device_id: int, # physical device id
30+
rank: int,
31+
local_rank: int,
32+
):
33+
super(DCUModelRunner, self).__init__(
34+
fd_config=fd_config, device=device, device_id=device_id, rank=rank, local_rank=local_rank
35+
)
36+
37+
def initialize_forward_meta(self):
38+
"""
39+
Initialize forward meta and attention meta data
40+
"""
41+
# Initialize forward meta
42+
self.forward_meta = DCUForwardMeta(
43+
input_ids=self.share_inputs["input_ids"],
44+
ids_remove_padding=self.share_inputs["ids_remove_padding"],
45+
rotary_embs=self.share_inputs["rope_emb"],
46+
attn_backend=self.attn_backends[0],
47+
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
48+
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
49+
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
50+
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
51+
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
52+
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
53+
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
54+
batch_id_per_token=self.share_inputs["batch_id_per_token"],
55+
cum_offsets=self.share_inputs["cum_offsets"],
56+
cu_seqlens_q=self.share_inputs["cu_seqlens_q"],
57+
cu_seqlens_k=self.share_inputs["cu_seqlens_k"],
58+
block_tables=self.share_inputs["block_tables"],
59+
caches=self.share_inputs["caches"],
60+
)
61+
62+
# Update Batch type for cuda graph
63+
only_decode_batch = True
64+
prefill_exists = None
65+
# mix ep in single node
66+
if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
67+
only_decode_batch_list = []
68+
prefill_exists = self.exist_prefill()
69+
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
70+
only_decode_batch = all(only_decode_batch_list)
71+
self.fd_config.parallel_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
72+
73+
self.forward_meta.step_use_cudagraph = (
74+
self.use_cudagraph
75+
and only_decode_batch
76+
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
77+
)
78+
79+
# Initialzie attention meta data
80+
for attn_backend in self.attn_backends:
81+
attn_backend.init_attention_metadata(self.forward_meta)

fastdeploy/worker/dcu_worker.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# limitations under the License.
1515
"""
1616

17+
import gc
1718
import time
1819

1920
import paddle
2021

2122
from fastdeploy.config import FDConfig
22-
from fastdeploy.utils import get_logger
23+
from fastdeploy.utils import get_logger, set_random_seed
24+
from fastdeploy.worker.dcu_model_runner import DCUModelRunner
2325
from fastdeploy.worker.gpu_worker import GpuWorker
2426

2527
logger = get_logger("dcu_worker", "dcu_worker.log")
@@ -41,6 +43,41 @@ def __init__(
4143
)
4244
pass
4345

46+
def init_device(self):
47+
"""
48+
Initialize device and construct model runner
49+
"""
50+
self.max_chips_per_node = 8
51+
if self.device_config.device_type == "cuda" and paddle.device.is_compiled_with_cuda():
52+
# Set evironment variable
53+
self.device_ids = self.parallel_config.device_ids.split(",")
54+
self.device = f"gpu:{self.local_rank % self.max_chips_per_node}"
55+
paddle.device.set_device(self.device)
56+
paddle.set_default_dtype(self.parallel_config.dtype)
57+
58+
gc.collect()
59+
paddle.device.cuda.empty_cache()
60+
if (
61+
self.parallel_config.enable_custom_all_reduce
62+
and self.parallel_config.tensor_parallel_size > 1
63+
and paddle.is_compiled_with_cuda()
64+
):
65+
from fastdeploy.distributed.communication import use_custom_allreduce
66+
67+
use_custom_allreduce()
68+
else:
69+
raise RuntimeError(f"Not support device type: {self.device_config.device}")
70+
71+
set_random_seed(self.fd_config.model_config.seed)
72+
# Construct model runner
73+
self.model_runner: DCUModelRunner = DCUModelRunner(
74+
fd_config=self.fd_config,
75+
device=self.device,
76+
device_id=self.device_ids[self.local_rank % self.max_chips_per_node],
77+
rank=self.rank,
78+
local_rank=self.local_rank,
79+
)
80+
4481
def determine_available_memory(self) -> int:
4582
"""
4683
Profiles the peak memory usage of the model to determine how much

fastdeploy/worker/gpu_model_runner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@
4646
if current_platform.is_iluvatar():
4747
from fastdeploy.model_executor.ops.iluvatar import set_value_by_flags_and_idx
4848

49+
recover_decode_task = None
50+
share_external_data = None
51+
elif current_platform.is_dcu():
52+
from fastdeploy.model_executor.ops.gpu import set_value_by_flags_and_idx
53+
4954
recover_decode_task = None
5055
share_external_data = None
5156
else:

scripts/run_ci_dcu.sh

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#!/bin/bash
2+
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
3+
echo "$DIR"
4+
5+
function stop_processes() {
6+
ps -efww | grep -E 'api_server' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
7+
ps -efww | grep -E '8188' | grep -v grep | awk '{print $2}' | xargs kill -9 || true
8+
lsof -t -i :8188 | xargs kill -9 || true
9+
}
10+
11+
echo "Clean up processes..."
12+
stop_processes
13+
echo "Clean up completed."
14+
15+
export model_path=${MODEL_PATH}/paddle/ERNIE-4.5-21B-A3B-Paddle
16+
17+
python -m pip install paddlepaddle_dcu==3.1.1 -i https://www.paddlepaddle.org.cn/packages/stable/dcu/
18+
python -m pip install https://paddle-whl.bj.bcebos.com/stable/dcu/triton/triton-3.0.0%2Bdas.opt4.0da70a2.dtk2504-cp310-cp310-manylinux_2_28_x86_64.whl
19+
20+
python -m pip install git+https://github.com/zhoutianzi666/UseTritonInPaddle.git
21+
python -c "import use_triton_in_paddle; use_triton_in_paddle.make_triton_compatible_with_paddle()"
22+
23+
echo "pip install requirements_dcu"
24+
python -m pip install -r requirements_dcu.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
25+
26+
echo "build whl"
27+
bash build.sh || exit 1
28+
29+
unset http_proxy
30+
unset https_proxy
31+
unset no_proxy
32+
33+
34+
rm -rf log/*
35+
rm -f core*
36+
37+
# Empty the message queue
38+
ipcrm --all=msg
39+
echo "Start server..."
40+
export FD_ATTENTION_BACKEND="BLOCK_ATTN"
41+
python -m fastdeploy.entrypoints.openai.api_server \
42+
--model ${model_path} \
43+
--port 8188 \
44+
--tensor-parallel-size 4 \
45+
--gpu-memory-utilization 0.8 \
46+
--quantization wint8 > server.log 2>&1 &
47+
48+
echo "Waiting 90 seconds..."
49+
sleep 90
50+
51+
if grep -q "Failed to launch worker processes" server.log; then
52+
echo "Failed to launch worker processes..."
53+
stop_processes
54+
cat server.log
55+
cat log/workerlog.0
56+
exit 1
57+
fi
58+
59+
if grep -q "Traceback (most recent call last):" server.log; then
60+
echo "Some errors occurred..."
61+
stop_processes
62+
cat server.log
63+
cat log/workerlog.0
64+
exit 1
65+
fi
66+
67+
# Health check
68+
TIMEOUT=$((5 * 60))
69+
INTERVAL=10 # Check interval (seconds)
70+
ENDPOINT="http://0.0.0.0:8188/health"
71+
START_TIME=$(date +%s) # Record the start timestamp
72+
echo "Start the server health check, maximum waiting time: ${TIMEOUT} seconds..."
73+
while true; do
74+
# Used to calculate the time cost
75+
CURRENT_TIME=$(date +%s)
76+
ELAPSED=$((CURRENT_TIME - START_TIME))
77+
78+
# Timeout
79+
if [ $ELAPSED -ge $TIMEOUT ]; then
80+
echo -e "\nServer start timeout: After $((TIMEOUT/60)) minutes, the service still doesn't start!"
81+
cat server.log
82+
cat log/workerlog.0
83+
exit 1
84+
fi
85+
86+
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" -m 2 "$ENDPOINT" || true)
87+
88+
if [ "$HTTP_CODE" = "200" ]; then
89+
echo -e "\nThe server was successfully launched! Totally takes $((ELAPSED+90)) seconds."
90+
break
91+
else
92+
sleep $INTERVAL
93+
fi
94+
done
95+
96+
cat server.log
97+
echo -e "\n"
98+
99+
echo "Start inference..."
100+
python test/ci_use/DCU/run_ernie.py
101+
exit_code=$?
102+
echo "exit_code is ${exit_code}.\n"
103+
104+
echo "Stop server..."
105+
stop_processes
106+
echo "Stop server done."
107+
108+
if [ ${exit_code} -ne 0 ]; then
109+
echo "Exit with error, please refer to log/workerlog.0"
110+
cat log/workerlog.0
111+
exit 1
112+
fi

0 commit comments

Comments
 (0)