diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 1e2f96c86a55..76ba7278b202 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -520,14 +520,9 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
-    elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        accumulator = accumulator * b_scale
+    elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
+        accumulator = accumulator * a_scale * b_scale
 
     # Since bias is typically not quantized, it's added after dequantization.
     if HAS_BIAS:
@@ -536,6 +531,8 @@ def fused_moe_kernel(
         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
+    accumulator = accumulator.to(compute_type)
+
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)