22#include < cudaTypedefs.h>
33#include < torch/all.h>
44
5+ #include < flashinfer/vec_dtypes.cuh>
56#include < iostream>
67
78#include " cutlass/array.h"
9+ #include " utils.h"
810
911constexpr uint64_t THREADS_PER_EXPERT = 512 ;
1012
@@ -268,23 +270,34 @@ __global__ void apply_shuffle_mul_sum_kernel(
268270 return ;
269271 }
270272
271- // Grid-stride loop to ensure each thread handles multiple feature dimensions
272- // if row_stride > blockDim.x..
273- for (int d = threadIdx .x ; d < row_stride; d += blockDim .x ) {
274- scalar_t sum_val = 0.0 ;
273+ constexpr uint32_t vec_size = 16 / sizeof (scalar_t );
274+ using vec_t = flashinfer::vec_t <scalar_t , vec_size>;
275+ int thread_idx = threadIdx .x ;
276+ int stride = blockDim .x ;
277+
278+ for (int d_vec_idx = thread_idx; d_vec_idx < row_stride / vec_size; d_vec_idx += stride) {
279+ int d = d_vec_idx * vec_size;
280+ vec_t sum_vec;
281+ sum_vec.fill (0 .0f );
275282
276283 for (int j = 0 ; j < topk; ++j) {
277284 int token_major_idx = i * topk + j;
278285 int src_row = permutation[token_major_idx];
279- scalar_t val = input_tensor[src_row * row_stride + d];
286+
287+ vec_t val_vec;
288+ val_vec.load (input_tensor + src_row * row_stride + d);
289+
280290 scalar_t factor = 1.0 ;
281291 if (factors != nullptr ) {
282292 factor = factors[token_major_idx];
283293 }
284- sum_val += factor * val;
285- }
286294
287- output_tensor[i * row_stride + d] = sum_val;
295+ #pragma unroll
296+ for (int k = 0 ; k < vec_size; ++k) {
297+ sum_vec[k] += factor * val_vec[k];
298+ }
299+ }
300+ sum_vec.store (output_tensor + i * row_stride + d);
288301 }
289302}
290303
@@ -304,7 +317,11 @@ void get_apply_shuffle_mul_sum_caller(
304317
305318 TORCH_CHECK (permutation.size (0 ) == m * topk, " permutation size must match m * topk" );
306319
307- dim3 block (std::min (256 , row_stride));
320+ auto scalar_type = output_tensor.scalar_type ();
321+ uint32_t vec_size = 16 / sizeof (scalar_type);
322+ auto blockDim = std::min (row_stride / vec_size, 1024U );
323+ dim3 block (blockDim );
324+
308325 dim3 grid (m); // blockIdx.x = j, blockIdx.y = i
309326 auto stream = at::cuda::getCurrentCUDAStream (input_tensor.device ().index ());
310327
@@ -317,29 +334,17 @@ void get_apply_shuffle_mul_sum_caller(
317334 factors_ptr = factors_opt->data_ptr ();
318335 }
319336
320- if (output_tensor.scalar_type () == at::ScalarType::Half) {
321- const at::Half* factor_data = static_cast <const at::Half*>(factors_ptr);
322- apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0 , stream>>> (
323- input_tensor.data_ptr <at::Half>(),
324- output_tensor.data_ptr <at::Half>(),
337+ DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16 (output_tensor.scalar_type (), scalar_t , [&] {
338+ apply_shuffle_mul_sum_kernel<scalar_t ><<<grid, block, 0 , stream>>> (
339+ static_cast <const scalar_t *>(input_tensor.data_ptr ()),
340+ static_cast <scalar_t *>(output_tensor.data_ptr ()),
325341 perm_ptr,
326342 m,
327343 topk,
328344 row_stride,
329- static_cast <const at::Half*>(factors_ptr));
330- } else if (output_tensor.scalar_type () == at::ScalarType::BFloat16) {
331- const c10::BFloat16* factor_data = static_cast <const c10::BFloat16*>(factors_ptr);
332- apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0 , stream>>> (
333- input_tensor.data_ptr <c10::BFloat16>(),
334- output_tensor.data_ptr <c10::BFloat16>(),
335- perm_ptr,
336- m,
337- topk,
338- row_stride,
339- static_cast <const c10::BFloat16*>(factors_ptr));
340- } else {
341- TORCH_CHECK (false , " Unsupported output dtype for cast+mul kernel: " , output_tensor.scalar_type ());
342- }
345+ static_cast <const scalar_t *>(factors_ptr));
346+ return true ;
347+ });
343348}
344349
345350/* *
0 commit comments