@@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel(
6767 }
6868}
6969
// Gathers expert ("down") outputs back into token order, weighting each
// contribution by its top-k routing weight.
//
// Launch layout: one block per token (grid.x == num_tokens); threads of the
// block cooperate across the hidden dimension. For each token the kernel sums
//   weight[k] * down_output[src2dst[k]]
// over the top-k experts whose id lies in [start_expert_id, end_expert_id];
// experts outside that range are skipped and contribute zero.
//
// Accumulation is done in float regardless of scalar_t to limit precision
// loss for fp16/bf16 inputs. The bulk of each row is processed with 16-byte
// vectorized loads; a scalar tail loop handles the remaining
// hidden_size % vec_size elements (previously these were never written,
// leaving garbage in the output when hidden_size was not vector-aligned).
template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
    const scalar_t* __restrict__ down_output_ptr,   // [num_src_rows, hidden_size]
    scalar_t* __restrict__ output_ptr,              // [num_tokens, hidden_size]
    const int* __restrict__ src2dst_ptr,            // [num_tokens, topk]
    const int* __restrict__ topk_ids_ptr,           // [num_tokens, topk]
    const scalar_t* __restrict__ topk_weights_ptr,  // [num_tokens, topk]
    int start_expert_id,
    int end_expert_id,
    int topk,
    int hidden_size) {
  const int token_idx = blockIdx.x;
  const int tid = threadIdx.x;

  // Per-token views into the routing metadata.
  const int* token_src2dst = src2dst_ptr + token_idx * topk;
  const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
  const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;

  // 64-bit row offset: token_idx * hidden_size can overflow 32 bits.
  scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;

  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  // Vectorized main loop over full 16-byte chunks of the hidden dimension.
  const int vec_iters = hidden_size / vec_size;
  for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
    float acc[vec_size] = {0};

    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      // Skip experts not owned by this expert-parallel rank.
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const int src_row = token_src2dst[k];
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
      const float weight = static_cast<float>(token_topk_weights[k]);

      vec_t src_vec;
      src_vec.cast_load(src_ptr + idx * vec_size);

#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        acc[i] += static_cast<float>(src_vec[i]) * weight;
      }
    }

    vec_t out_vec;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      out_vec[i] = static_cast<scalar_t>(acc[i]);
    }
    out_vec.cast_store(dst_ptr + idx * vec_size);
  }

  // Scalar tail: cover the hidden_size % vec_size trailing elements so every
  // element of the output row is written even for unaligned hidden sizes.
  for (int idx = static_cast<int>(vec_iters * vec_size) + tid; idx < hidden_size;
       idx += blockDim.x) {
    float acc = 0.0f;
    for (int k = 0; k < topk; ++k) {
      const int expert_id = token_topk_ids[k];
      if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
      const int src_row = token_src2dst[k];
      const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
      acc += static_cast<float>(src_ptr[idx]) * static_cast<float>(token_topk_weights[k]);
    }
    dst_ptr[idx] = static_cast<scalar_t>(acc);
  }
}
120+
70121void ep_moe_pre_reorder (
71122 torch::Tensor input,
72123 torch::Tensor gateup_input,
@@ -77,8 +128,8 @@ void ep_moe_pre_reorder(
77128 int64_t end_expert_id,
78129 int64_t topk,
79130 bool use_per_token_if_dynamic) {
80- int total_blocks = input.size (0 );
81- int block_size = 512 ;
131+ const int total_blocks = input.size (0 );
132+ const int block_size = 512 ;
82133 dim3 grid (total_blocks);
83134 dim3 block (block_size);
84135 int hidden_size = input.size (1 );
@@ -98,3 +149,33 @@ void ep_moe_pre_reorder(
98149 return true ;
99150 });
100151}
152+
// Host-side launcher for ep_post_reorder_cuda_kernel.
//
// Scatters per-expert `down_output` rows back into token order in `output`,
// weighting each contribution by `topk_weights`. One CUDA block (512 threads)
// is launched per output token.
//
//   down_output  : [num_src_rows, hidden_size], dtype selects scalar_t.
//   output       : [num_tokens, hidden_size], fully overwritten.
//   src2dst      : [num_tokens, topk] int32 mapping into down_output rows.
//   topk_ids     : [num_tokens, topk] int32 expert ids.
//   topk_weights : [num_tokens, topk]; its raw pointer is reinterpreted as
//                  scalar_t* — assumes it matches down_output's dtype.
//                  NOTE(review): confirm the caller casts it accordingly.
//
// NOTE(review): the kernel is launched on the default stream; if callers rely
// on the current PyTorch stream, plumb it through explicitly.
void ep_moe_post_reorder(
    torch::Tensor down_output,
    torch::Tensor output,
    torch::Tensor src2dst,
    torch::Tensor topk_ids,
    torch::Tensor topk_weights,
    int64_t start_expert_id,
    int64_t end_expert_id,
    int64_t topk) {
  const int total_tokens = output.size(0);
  // A zero-sized grid dimension is an invalid launch configuration; with no
  // tokens there is nothing to do, so return early instead of erroring.
  if (total_tokens == 0) {
    return;
  }
  const int block_size = 512;
  dim3 grid(total_tokens);
  dim3 block(block_size);
  const int hidden_size = output.size(1);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
    ep_post_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
        static_cast<scalar_t*>(down_output.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()),
        src2dst.data_ptr<int>(),
        topk_ids.data_ptr<int>(),
        static_cast<scalar_t*>(topk_weights.data_ptr()),
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        static_cast<int>(topk),
        hidden_size);
    return true;
  });
}
0 commit comments