From 73234d9b6355380083e813b140b3efcc7911a803 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Sun, 4 Jan 2026 09:04:17 +0000 Subject: [PATCH 1/2] fix bias adding for triton implemented fused_moe_kernel Signed-off-by: xuebwang-amd --- vllm/model_executor/layers/fused_moe/fused_moe.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b434780e19a2..f25583096178 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -518,11 +518,7 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - if HAS_BIAS: - accumulator = accumulator + bias[None, :] - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) - accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8 or use_int8_w8a8: @@ -533,6 +529,13 @@ def fused_moe_kernel( else: accumulator = accumulator.to(compute_type) + # Since bias is typically not quantized, it's added after dequantization. + if HAS_BIAS: + accumulator = accumulator + bias[None, :] + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + # ----------------------------------------------------------- # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) From 08bb2408d29a1b53fe3a3cb5b8bea1b3dee36977 Mon Sep 17 00:00:00 2001 From: xuebwang-amd Date: Sun, 4 Jan 2026 09:06:04 +0000 Subject: [PATCH 2/2] a tiny code lint issue Signed-off-by: xuebwang-amd --- vllm/model_executor/layers/fused_moe/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index f25583096178..1e2f96c86a55 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -518,7 +518,7 @@ def fused_moe_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8 or use_int8_w8a8: