Skip to content

Commit f2e7158

Browse files
committed
update
1 parent ab1be02 commit f2e7158

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

python/sglang/srt/layers/moe/cutlass_moe.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def cutlass_fused_experts_fp8(
209209
)
210210

211211
result = torch.empty((m, k), device=device, dtype=out_dtype)
212-
return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
212+
apply_shuffle_mul_sum(c2, result, c_map, topk_weights.to(out_dtype))
213+
return result
213214

214215

215216
FLOAT4_E2M1_MAX = 6.0

sgl-kernel/csrc/moe/prepare_moe_input.cu

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ __global__ void apply_shuffle_mul_sum_kernel(
271271
}
272272

273273
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
274-
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
274+
using t = float;
275+
using vec_t = flashinfer::vec_t<t, vec_size>;
275276
int thread_idx = threadIdx.x;
276277
int stride = blockDim.x;
277278

@@ -285,9 +286,9 @@ __global__ void apply_shuffle_mul_sum_kernel(
285286
int src_row = permutation[token_major_idx];
286287

287288
vec_t val_vec;
288-
val_vec.load(input_tensor + src_row * row_stride + d);
289+
val_vec.cast_load(input_tensor + src_row * row_stride + d);
289290

290-
scalar_t factor = 1.0;
291+
t factor = 1.0;
291292
if (factors != nullptr) {
292293
factor = factors[token_major_idx];
293294
}
@@ -297,7 +298,25 @@ __global__ void apply_shuffle_mul_sum_kernel(
297298
sum_vec[k] += factor * val_vec[k];
298299
}
299300
}
300-
sum_vec.store(output_tensor + i * row_stride + d);
301+
sum_vec.cast_store(output_tensor + i * row_stride + d);
302+
}
303+
304+
// Scalar tail: handle columns [remainder_start, row_stride), i.e. those past the last full vec_size-wide chunk.
305+
int remainder_start = (row_stride / vec_size) * vec_size;
306+
for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
307+
t sum_val = 0.0;
308+
for (int j = 0; j < topk; ++j) {
309+
int token_major_idx = i * topk + j;
310+
int src_row = permutation[token_major_idx];
311+
t val = input_tensor[src_row * row_stride + d];
312+
313+
t factor = 1.0;
314+
if (factors != nullptr) {
315+
factor = factors[token_major_idx];
316+
}
317+
sum_val += factor * val;
318+
}
319+
output_tensor[i * row_stride + d] = sum_val;
301320
}
302321
}
303322

0 commit comments

Comments
 (0)