Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,14 +520,9 @@ def fused_moe_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk

if use_int8_w8a16:
accumulator = (accumulator * b_scale).to(compute_type)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
accumulator = accumulator.to(compute_type)
else:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
accumulator = accumulator * b_scale
elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
accumulator = accumulator * a_scale * b_scale

# Since bias is typically not quantized, it's added after dequantization.
if HAS_BIAS:
Expand All @@ -536,6 +531,8 @@ def fused_moe_kernel(
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]

accumulator = accumulator.to(compute_type)

# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
Expand Down