Skip to content

Commit 3b82528

Browse files
committed
fix: fix apply_shuffle_mul_sum
1 parent edc21cc commit 3b82528

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

sgl-kernel/csrc/moe/prepare_moe_input.cu

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -255,37 +255,37 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr
255255

256256
template <typename scalar_t>
257257
__global__ void apply_shuffle_mul_sum_kernel(
258-
const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride]
259-
scalar_t* __restrict__ output_tensor, // [m, row_stride]
260-
const int32_t* __restrict__ permutation, // [m * topk]
258+
const scalar_t* __restrict__ input_tensor, // [m * topk, k] (expert-major layout)
259+
scalar_t* __restrict__ output_tensor, // [m, k] (token-major layout)
260+
const int32_t* __restrict__ permutation, // [m * topk] (c_map: token-major-idx -> expert-major-idx)
261261
int m,
262262
int topk,
263263
int row_stride,
264-
const scalar_t* __restrict__ factors) // [m * topk] or nullptr
264+
const scalar_t* __restrict__ factors) // [m * topk] (topk_weights, token-major layout)
265265
{
266-
int i = blockIdx.x; // [0, m * topk)
267-
int d = threadIdx.x; // [0, row_stride)
268-
269-
if (i >= m || d >= row_stride) return;
270-
271-
scalar_t sum_val = 0.0;
272-
273-
for (int j = 0; j < topk; ++j) {
274-
int index_2d = i * topk + j;
275-
int src_row = permutation[index_2d];
276-
if (src_row >= m) continue;
277-
278-
scalar_t val = input_tensor[src_row * row_stride + d];
266+
int i = blockIdx.x;
267+
if (i >= m) {
268+
return;
269+
}
279270

280-
scalar_t factor = 1.0;
281-
if (factors != nullptr) {
282-
factor = factors[index_2d];
271+
// Grid-stride loop to ensure each thread handles multiple feature dimensions
272+
// if row_stride > blockDim.x..
273+
for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
274+
scalar_t sum_val = 0.0;
275+
276+
for (int j = 0; j < topk; ++j) {
277+
int token_major_idx = i * topk + j;
278+
int src_row = permutation[token_major_idx];
279+
scalar_t val = input_tensor[src_row * row_stride + d];
280+
scalar_t factor = 1.0;
281+
if (factors != nullptr) {
282+
factor = factors[token_major_idx];
283+
}
284+
sum_val += factor * val;
283285
}
284286

285-
sum_val += factor * val;
287+
output_tensor[i * row_stride + d] = sum_val;
286288
}
287-
288-
output_tensor[i * row_stride + d] = sum_val;
289289
}
290290

291291
void get_apply_shuffle_mul_sum_caller(

0 commit comments

Comments
 (0)