@@ -255,37 +255,37 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
255255
// Weighted reduction of the top-k expert outputs of every token:
//
//   output_tensor[i, d] = sum_{j < topk} factors[i * topk + j]
//                                        * input_tensor[permutation[i * topk + j], d]
//
// When `factors` is nullptr every weight is taken as 1.
//
// Launch layout:
//   * One block per token: blockIdx.x is the token index i. Blocks with
//     blockIdx.x >= m exit immediately, so gridDim.x may exceed m; rows with
//     index >= gridDim.x are never written.
//   * Threads of a block stride over the feature dimension in steps of
//     blockDim.x, so any block size is correct, including
//     blockDim.x < row_stride.
//   * No shared memory and no intra-block synchronization are used.
//
// Parameters:
//   input_tensor  [m * topk, row_stride] expert outputs (expert-major layout)
//   output_tensor [m, row_stride]        combined result (token-major layout)
//   permutation   [m * topk]             token-major index -> expert-major row
//   m             number of tokens (output rows)
//   topk          number of expert rows combined per token
//   row_stride    number of elements per row; rows are addressed as densely
//                 packed (row * row_stride + d) in both tensors
//   factors       [m * topk] per-(token, expert) weights in token-major layout,
//                 or nullptr for an unweighted sum
//
// NOTE(review): permutation entries are used as row indices without any bounds
// check; the caller must guarantee every entry is a valid row of input_tensor.
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,
    scalar_t* __restrict__ output_tensor,
    const int32_t* __restrict__ permutation,
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors)
{
  const int i = blockIdx.x;
  if (i >= m) {
    return;
  }

  // All offsets are computed in 64 bits: m * topk * row_stride can exceed
  // INT32_MAX for large batches, and a wrapped 32-bit product would address
  // memory outside the tensors.
  const int64_t token_base = static_cast<int64_t>(i) * topk;
  scalar_t* out_row = output_tensor + static_cast<int64_t>(i) * row_stride;

  // Block-stride loop over the feature dimension: each thread handles columns
  // threadIdx.x, threadIdx.x + blockDim.x, ... so consecutive threads touch
  // consecutive elements of a row.
  for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
    scalar_t sum_val = 0.0;

    for (int j = 0; j < topk; ++j) {
      const int64_t token_major_idx = token_base + j;
      const int64_t src_row = permutation[token_major_idx];
      const scalar_t val = input_tensor[src_row * row_stride + d];

      scalar_t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[token_major_idx];
      }
      sum_val += factor * val;
    }

    out_row[d] = sum_val;
  }
}
290290
291291void get_apply_shuffle_mul_sum_caller (
0 commit comments