@@ -271,7 +271,8 @@ __global__ void apply_shuffle_mul_sum_kernel(
271271 }
272272
273273 constexpr uint32_t vec_size = 16 / sizeof (scalar_t );
274- using vec_t = flashinfer::vec_t <scalar_t , vec_size>;
274+ using t = float ;
275+ using vec_t = flashinfer::vec_t <t, vec_size>;
275276 int thread_idx = threadIdx .x ;
276277 int stride = blockDim .x ;
277278
@@ -285,9 +286,9 @@ __global__ void apply_shuffle_mul_sum_kernel(
285286 int src_row = permutation[token_major_idx];
286287
287288 vec_t val_vec;
288- val_vec.load (input_tensor + src_row * row_stride + d);
289+ val_vec.cast_load (input_tensor + src_row * row_stride + d);
289290
290- scalar_t factor = 1.0 ;
291+ t factor = 1.0 ;
291292 if (factors != nullptr ) {
292293 factor = factors[token_major_idx];
293294 }
@@ -297,7 +298,25 @@ __global__ void apply_shuffle_mul_sum_kernel(
297298 sum_vec[k] += factor * val_vec[k];
298299 }
299300 }
300- sum_vec.store (output_tensor + i * row_stride + d);
301+ sum_vec.cast_store (output_tensor + i * row_stride + d);
302+ }
303+
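304+ // Scalar tail: handle the trailing columns from (row_stride / vec_size) * vec_size up to row_stride, one element at a time.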
305+ int remainder_start = (row_stride / vec_size) * vec_size;
306+ for (int d = remainder_start + thread_idx; d < row_stride; d += stride) {
307+ t sum_val = 0.0 ;
308+ for (int j = 0 ; j < topk; ++j) {
309+ int token_major_idx = i * topk + j;
310+ int src_row = permutation[token_major_idx];
311+ t val = input_tensor[src_row * row_stride + d];
312+
313+ t factor = 1.0 ;
314+ if (factors != nullptr ) {
315+ factor = factors[token_major_idx];
316+ }
317+ sum_val += factor * val;
318+ }
319+ output_tensor[i * row_stride + d] = sum_val;
301320 }
302321}
303322
0 commit comments