Skip to content

Commit d4154c3

Browse files
jinzhen-linmgoin
andauthored
[Bugfix] fix moe marlin topk_weight loading (#18080)
Co-authored-by: mgoin <[email protected]>
1 parent 6685890 commit d4154c3

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

csrc/moe/marlin_moe_wna16/marlin_template.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,15 @@ __global__ void Marlin(
473473
if (mul_topk_weights) {
474474
#pragma unroll
475475
for (int i = 0; i < 4; i++) {
476+
int idx = tid4 * 4 + i;
477+
idx = idx < block_num_valid_tokens ? idx : 0;
476478
if constexpr (w_type == vllm::kFE2M1f) {
477-
sh_block_topk_weights[tid4 * 4 + i] = __hmul2(
478-
global_scale,
479-
Dtype::num2num2(Dtype::float2num(
480-
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])));
479+
sh_block_topk_weights[idx] = __hmul2(
480+
global_scale, Dtype::num2num2(Dtype::float2num(
481+
topk_weights_ptr[sh_block_sorted_ids[idx]])));
481482
} else {
482-
sh_block_topk_weights[tid4 * 4 + i] =
483-
Dtype::num2num2(Dtype::float2num(
484-
topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
483+
sh_block_topk_weights[idx] = Dtype::num2num2(
484+
Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]));
485485
}
486486
}
487487
}

0 commit comments

Comments
 (0)