diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d1b25aa92e8d..ad13c17418b5 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -532,22 +532,37 @@ def fused_moe_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk

-    # Router weight multiplication MUST happen in float32 before precision
-    # conversion for numerical stability (especially critical on ROCm).
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
-
+    # Dequantization for supported quantization schemes:
+    # - int8_w8a16
+    # - fp8_w8a8
+    # - int8_w8a8
+    # Accumulator and scalings are in float32 to preserve numerical accuracy.
     if use_int8_w8a16:
         accumulator = accumulator * b_scale
     elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
         accumulator = accumulator * a_scale * b_scale

-    # Bias is added AFTER dequantization since bias is typically stored in
-    # the output dtype and should not be scaled by quantization factors.
+    # Bias addition:
+    # Bias must be applied after dequantization:
+    # - Since bias is typically not quantized
+    # - Bias should not be scaled by quantization factors
     if HAS_BIAS:
-        accumulator = accumulator + bias[None, :]
+        accumulator += bias[None, :]
+
+    # Router (MoE) weight multiplication:
+    # This multiplication MUST be performed in float32 before any precision
+    # conversion to ensure numerical stability, which is especially critical
+    # on ROCm platforms.
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(
+            topk_weights_ptr + offs_token,
+            mask=token_mask,
+            other=0,
+        )
+        accumulator *= moe_weight[:, None]

+    # Final precision conversion:
+    # Cast once at the end to the desired compute/output dtype.
     accumulator = accumulator.to(compute_type)

     # -----------------------------------------------------------
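
For reference, a minimal plain-PyTorch sketch of the epilogue ordering the patch establishes (dequantize, then bias, then routed weight in float32, then a single cast). This is illustrative only and not part of the patch; the tensor names mirror the kernel's variables, but the shapes, scales, and dtypes are made-up stand-ins.

```python
import torch

M, N = 4, 8                                   # tokens x output features (hypothetical)
acc = torch.randn(M, N, dtype=torch.float32)  # fp32 accumulator from the matmul
a_scale = torch.tensor(0.05)                  # per-tensor activation scale (assumed)
b_scale = torch.rand(N)                       # per-channel weight scale (assumed)
bias = torch.randn(N)                         # unquantized bias
topk_weights = torch.rand(M)                  # router weight for each token's expert

# 1) Dequantize in float32.
acc = acc * a_scale * b_scale
# 2) Add bias after dequantization, so it is not scaled by quantization factors.
acc = acc + bias[None, :]
# 3) Apply the routed (MoE) weight while still in float32.
acc = acc * topk_weights[:, None]
# 4) Single precision conversion at the very end.
out = acc.to(torch.bfloat16)
print(out.shape, out.dtype)
```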