File tree Expand file tree Collapse file tree 1 file changed +7
-7
lines changed
csrc/moe/marlin_moe_wna16 Expand file tree Collapse file tree 1 file changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -473,15 +473,15 @@ __global__ void Marlin(
473473 if (mul_topk_weights) {
474474 #pragma unroll
475475 for (int i = 0 ; i < 4 ; i++) {
476+ int idx = tid4 * 4 + i;
477+ idx = idx < block_num_valid_tokens ? idx : 0 ;
476478 if constexpr (w_type == vllm::kFE2M1f ) {
477- sh_block_topk_weights[tid4 * 4 + i] = __hmul2 (
478- global_scale,
479- Dtype::num2num2 (Dtype::float2num (
480- topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
479+ sh_block_topk_weights[idx] = __hmul2 (
480+ global_scale, Dtype::num2num2 (Dtype::float2num (
481+ topk_weights_ptr[sh_block_sorted_ids[idx]])));
481482 } else {
482- sh_block_topk_weights[tid4 * 4 + i] =
483- Dtype::num2num2 (Dtype::float2num (
484- topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
483+ sh_block_topk_weights[idx] = Dtype::num2num2 (
484+ Dtype::float2num (topk_weights_ptr[sh_block_sorted_ids[idx]]));
485485 }
486486 }
487487 }
You can’t perform that action at this time.
0 commit comments