Skip to content

Commit ab1be02

Browse files
committed
update
1 parent 3b82528 commit ab1be02

File tree

1 file changed

+33
-28
lines changed

1 file changed

+33
-28
lines changed

sgl-kernel/csrc/moe/prepare_moe_input.cu

100755100644
Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
#include <cudaTypedefs.h>
33
#include <torch/all.h>
44

5+
#include <flashinfer/vec_dtypes.cuh>
56
#include <iostream>
67

78
#include "cutlass/array.h"
9+
#include "utils.h"
810

911
constexpr uint64_t THREADS_PER_EXPERT = 512;
1012

@@ -268,23 +270,34 @@ __global__ void apply_shuffle_mul_sum_kernel(
268270
return;
269271
}
270272

271-
// Grid-stride loop to ensure each thread handles multiple feature dimensions
272-
// if row_stride > blockDim.x..
273-
for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
274-
scalar_t sum_val = 0.0;
273+
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
274+
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
275+
int thread_idx = threadIdx.x;
276+
int stride = blockDim.x;
277+
278+
for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
279+
int d = d_vec_idx * vec_size;
280+
vec_t sum_vec;
281+
sum_vec.fill(0.0f);
275282

276283
for (int j = 0; j < topk; ++j) {
277284
int token_major_idx = i * topk + j;
278285
int src_row = permutation[token_major_idx];
279-
scalar_t val = input_tensor[src_row * row_stride + d];
286+
287+
vec_t val_vec;
288+
val_vec.load(input_tensor + src_row * row_stride + d);
289+
280290
scalar_t factor = 1.0;
281291
if (factors != nullptr) {
282292
factor = factors[token_major_idx];
283293
}
284-
sum_val += factor * val;
285-
}
286294

287-
output_tensor[i * row_stride + d] = sum_val;
295+
#pragma unroll
296+
for (int k = 0; k < vec_size; ++k) {
297+
sum_vec[k] += factor * val_vec[k];
298+
}
299+
}
300+
sum_vec.store(output_tensor + i * row_stride + d);
288301
}
289302
}
290303

@@ -304,7 +317,11 @@ void get_apply_shuffle_mul_sum_caller(
304317

305318
TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
306319

307-
dim3 block(std::min(256, row_stride));
320+
auto scalar_type = output_tensor.scalar_type();
321+
uint32_t vec_size = 16 / sizeof(scalar_type);
322+
auto blockDim = std::min(row_stride / vec_size, 1024U);
323+
dim3 block(blockDim);
324+
308325
dim3 grid(m); // blockIdx.x = j, blockIdx.y = i
309326
auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
310327

@@ -317,29 +334,17 @@ void get_apply_shuffle_mul_sum_caller(
317334
factors_ptr = factors_opt->data_ptr();
318335
}
319336

320-
if (output_tensor.scalar_type() == at::ScalarType::Half) {
321-
const at::Half* factor_data = static_cast<const at::Half*>(factors_ptr);
322-
apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
323-
input_tensor.data_ptr<at::Half>(),
324-
output_tensor.data_ptr<at::Half>(),
337+
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(output_tensor.scalar_type(), scalar_t, [&] {
338+
apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
339+
static_cast<const scalar_t*>(input_tensor.data_ptr()),
340+
static_cast<scalar_t*>(output_tensor.data_ptr()),
325341
perm_ptr,
326342
m,
327343
topk,
328344
row_stride,
329-
static_cast<const at::Half*>(factors_ptr));
330-
} else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
331-
const c10::BFloat16* factor_data = static_cast<const c10::BFloat16*>(factors_ptr);
332-
apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
333-
input_tensor.data_ptr<c10::BFloat16>(),
334-
output_tensor.data_ptr<c10::BFloat16>(),
335-
perm_ptr,
336-
m,
337-
topk,
338-
row_stride,
339-
static_cast<const c10::BFloat16*>(factors_ptr));
340-
} else {
341-
TORCH_CHECK(false, "Unsupported output dtype for cast+mul kernel: ", output_tensor.scalar_type());
342-
}
345+
static_cast<const scalar_t*>(factors_ptr));
346+
return true;
347+
});
343348
}
344349

345350
/**

0 commit comments

Comments
 (0)