Skip to content

Commit 43baba6

Browse files
yuan-luo and luoyuan.luo authored
[EP] Add cuda kernel for moe_ep_post_reorder (#6837)
Co-authored-by: luoyuan.luo <[email protected]>
1 parent 0166403 commit 43baba6

File tree

7 files changed

+377
-4
lines changed

7 files changed

+377
-4
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
import triton
3+
from sgl_kernel import ep_moe_post_reorder
4+
5+
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
6+
7+
batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
8+
configs = [(bs,) for bs in batch_sizes]
9+
10+
11+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[list(_) for _ in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-post-reorder-performance",
        args={},
    )
)
def benchmark(batch_size, provider):
    """Time the CUDA vs. Triton EP-MoE post-reorder kernels for one batch size.

    Returns the (median, q0.8, q0.2) latencies in microseconds, as produced by
    triton.testing.do_bench with quantiles [0.5, 0.2, 0.8].
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")
    hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512

    if provider not in ("cuda", "triton"):
        raise ValueError(f"Unknown provider: {provider}")

    def make_inputs():
        # Fresh tensors per run so both providers see identically shaped,
        # independent inputs.
        gathered = torch.randn(
            batch_size * topk, hidden_size, dtype=dtype, device=device
        )
        combined = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
        mapping = torch.randint(
            0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
        )
        expert_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size, topk),
            dtype=torch.int32,
            device=device,
        )
        weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
        return gathered, combined, mapping, expert_ids, weights

    gathered, combined, mapping, expert_ids, weights = make_inputs()

    if provider == "cuda":

        def run():
            ep_moe_post_reorder(
                gathered,
                combined,
                mapping,
                expert_ids,
                weights,
                start_expert_id,
                end_expert_id,
                topk,
            )

    else:  # provider == "triton"

        def run():
            post_reorder_triton_kernel[(batch_size,)](
                gathered.view(-1),
                combined.view(-1),
                mapping.view(-1),
                expert_ids.view(-1),
                weights.view(-1),
                start_expert_id,
                end_expert_id,
                topk,
                hidden_size,
                block_size,
            )

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
89+
90+
91+
# Entry point: sweep the perf_report over all configured batch sizes and
# print the resulting latency table.
if __name__ == "__main__":
    benchmark.run(print_data=True)

sgl-kernel/csrc/common_extension.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
174174
"(Tensor[])");
175175
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
176176
m.def(
177-
"ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor "
178-
"a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
177+
"ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
178+
"a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
179179
m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
180+
m.def(
181+
"ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
182+
"topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
183+
m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder);
180184
m.def(
181185
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
182186
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "

sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel(
6767
}
6868
}
6969

70+
// EP-MoE post-reorder: gather-reduce expert down-projection rows back into
// token order. For each destination token t:
//   output[t][:] = sum_k topk_weights[t][k] * down_output[src2dst[t][k]][:]
// restricted to experts in [start_expert_id, end_expert_id] (experts owned by
// other EP ranks are skipped; a token with no local expert is written as 0).
// Launch: one thread block per token (blockIdx.x == token index); threads
// stride over hidden_size in 16-byte vectors. Accumulation is in float
// regardless of scalar_t.
template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
    const scalar_t* __restrict__ down_output_ptr,
    scalar_t* __restrict__ output_ptr,
    const int* __restrict__ src2dst_ptr,
    const int* __restrict__ topk_ids_ptr,
    const scalar_t* __restrict__ topk_weights_ptr,
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size) {
  const int token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
  const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;

  // 64-bit row offsets: token_idx * hidden_size can overflow 32 bits.
  scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  // Vectorized main loop over the full 16-byte chunks of the hidden dim.
  const int vec_iters = hidden_size / vec_size;
  for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
    float acc[vec_size] = {0};

    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      // Expert lives on another EP rank; its contribution is accumulated there.
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const int src_row = token_src2dst[k];
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
      const float weight = static_cast<float>(token_topk_weights[k]);

      vec_t src_vec;
      src_vec.cast_load(src_ptr + idx * vec_size);

#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        acc[i] += static_cast<float>(src_vec[i]) * weight;
      }
    }
    vec_t out_vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i)
      out_vec[i] = static_cast<scalar_t>(acc[i]);

    out_vec.cast_store(dst_ptr + idx * vec_size);
  }

  // Scalar tail for hidden_size not divisible by vec_size. Previously these
  // trailing columns were silently left unwritten; for divisible sizes this
  // loop body never executes, so behavior is unchanged there.
  const int tail_start = vec_iters * static_cast<int>(vec_size);
  for (int col = tail_start + tid; col < hidden_size; col += blockDim.x) {
    float acc = 0.f;
    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(token_src2dst[k]) * hidden_size;
      acc += static_cast<float>(src_ptr[col]) * static_cast<float>(token_topk_weights[k]);
    }
    dst_ptr[col] = static_cast<scalar_t>(acc);
  }
}
120+
70121
void ep_moe_pre_reorder(
71122
torch::Tensor input,
72123
torch::Tensor gateup_input,
@@ -77,8 +128,8 @@ void ep_moe_pre_reorder(
77128
int64_t end_expert_id,
78129
int64_t topk,
79130
bool use_per_token_if_dynamic) {
80-
int total_blocks = input.size(0);
81-
int block_size = 512;
131+
const int total_blocks = input.size(0);
132+
const int block_size = 512;
82133
dim3 grid(total_blocks);
83134
dim3 block(block_size);
84135
int hidden_size = input.size(1);
@@ -98,3 +149,33 @@ void ep_moe_pre_reorder(
98149
return true;
99150
});
100151
}
152+
153+
// Host launcher for ep_post_reorder_cuda_kernel: one 512-thread block per
// output token (output.size(0) blocks). Tensors must be contiguous CUDA
// tensors; src2dst / topk_ids are int32, and down_output / output /
// topk_weights share the same float/fp16/bf16 dtype.
void ep_moe_post_reorder(
    torch::Tensor down_output,
    torch::Tensor output,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor topk_weights,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk) {
  // The kernel indexes raw pointers assuming dense row-major layout.
  TORCH_CHECK(
      down_output.is_contiguous() && output.is_contiguous() && src2dst.is_contiguous() &&
          topk_ids.is_contiguous() && topk_weights.is_contiguous(),
      "ep_moe_post_reorder: all tensors must be contiguous");
  TORCH_CHECK(
      down_output.scalar_type() == output.scalar_type() && output.scalar_type() == topk_weights.scalar_type(),
      "ep_moe_post_reorder: down_output, output and topk_weights must share one dtype");

  const int total_tokens = output.size(0);
  // grid.x must be >= 1; an empty batch has nothing to reorder.
  if (total_tokens == 0) {
    return;
  }
  const int block_size = 512;
  dim3 grid(total_tokens);
  dim3 block(block_size);
  const int hidden_size = output.size(1);

  // NOTE(review): launches on the default stream; if this op must compose
  // with stream capture / async pipelines, switch to
  // at::cuda::getCurrentCUDAStream() — confirm against the sibling launchers.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
    ep_post_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
        static_cast<scalar_t*>(down_output.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        static_cast<scalar_t*>(topk_weights.data_ptr()),
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        static_cast<int>(topk),
        hidden_size);
    return true;
  });
}

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,16 @@ void ep_moe_pre_reorder(
264264
int64_t topk,
265265
bool use_per_token_if_dynamic);
266266

267+
// EP-MoE post-reorder: weighted gather-reduce of expert down-projection rows
// (down_output, indexed via src2dst) back into per-token rows of output,
// accumulating only experts in [start_expert_id, end_expert_id].
// Implemented in csrc/moe/ep_moe_reorder_kernel.cu; registered as
// sgl_kernel::ep_moe_post_reorder in common_extension.cc.
void ep_moe_post_reorder(
    torch::Tensor down_output,
    torch::Tensor output,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor topk_weights,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk);
276+
267277
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
268278

269279
void cutlass_fp4_group_mm(

sgl-kernel/python/sgl_kernel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
5050
from sgl_kernel.moe import (
5151
cutlass_fp4_group_mm,
52+
ep_moe_post_reorder,
5253
ep_moe_pre_reorder,
5354
fp8_blockwise_scaled_grouped_mm,
5455
moe_align_block_size,

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,28 @@ def ep_moe_pre_reorder(
8888
)
8989

9090

91+
def ep_moe_post_reorder(
    down_output: torch.Tensor,
    output: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
) -> None:
    """Thin wrapper over the registered CUDA op ``sgl_kernel::ep_moe_post_reorder``.

    Forwards all arguments unchanged to the extension, which combines expert
    down-projection rows from ``down_output`` into ``output`` in place
    (the op's schema returns ``()``, so this returns ``None``).

    Args:
        down_output: expert-ordered rows to be gathered.
        output: destination tensor, written in place.
        src2dst: int32 mapping from (token, k) slots to rows of ``down_output``.
        topk_ids: int32 expert id per (token, k) slot.
        topk_weights: per-slot combine weights.
        start_expert_id / end_expert_id: inclusive range of experts handled
            by this rank; slots outside the range are skipped by the kernel.
        topk: number of experts selected per token.
    """
    return torch.ops.sgl_kernel.ep_moe_post_reorder.default(
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )
111+
112+
91113
def fp8_blockwise_scaled_grouped_mm(
92114
output,
93115
a_ptrs,

0 commit comments

Comments
 (0)