Merged
33 changes: 24 additions & 9 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -532,22 +532,37 @@ def fused_moe_kernel(
  a_ptrs += BLOCK_SIZE_K * stride_ak
  b_ptrs += BLOCK_SIZE_K * stride_bk

- # Router weight multiplication MUST happen in float32 before precision
- # conversion for numerical stability (especially critical on ROCm).
- if MUL_ROUTED_WEIGHT:
-     moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-     accumulator = accumulator * moe_weight[:, None]

+ # Dequantization for supported quantization schemes:
+ # - int8_w8a16
+ # - fp8_w8a8
+ # - int8_w8a8
+ # Accumulator and scalings are in float32 to preserve numerical accuracy.
  if use_int8_w8a16:
      accumulator = accumulator * b_scale
  elif (use_fp8_w8a8 or use_int8_w8a8) and not (group_k > 0 and group_n > 0):
      accumulator = accumulator * a_scale * b_scale

- # Bias is added AFTER dequantization since bias is typically stored in
- # the output dtype and should not be scaled by quantization factors.
+ # Bias addition:
+ # Bias must be applied after dequantization:
+ # - Since bias is typically not quantized
+ # - Bias should not be scaled by quantization factors
  if HAS_BIAS:
-     accumulator = accumulator + bias[None, :]
+     accumulator += bias[None, :]

+ # Router (MoE) weight multiplication:
+ # This multiplication MUST be performed in float32 before any precision
+ # conversion to ensure numerical stability, which is especially critical
+ # on ROCm platforms.
+ if MUL_ROUTED_WEIGHT:
+     moe_weight = tl.load(
+         topk_weights_ptr + offs_token,
+         mask=token_mask,
+         other=0,
+     )
+     accumulator *= moe_weight[:, None]

+ # Final precision conversion:
+ # Cast once at the end to the desired compute/output dtype.
  accumulator = accumulator.to(compute_type)
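The epilogue ordering described in the comments can be sanity-checked outside Triton. The NumPy sketch below uses made-up scales, bias values, and routed weights (none of these numbers come from vLLM) to show why the bias must be added after the quantization scales are applied, and before the routed weight:

```python
import numpy as np

# Hypothetical int8 W8A8 tile: the accumulator holds integer dot products
# promoted to float32, as in the kernel epilogue.
acc = np.array([[1200.0, -800.0], [400.0, 1600.0]], dtype=np.float32)
a_scale, b_scale = 0.01, 0.02                        # made-up per-tensor scales
bias = np.array([0.5, -0.25], dtype=np.float32)      # stored in output dtype
moe_weight = np.array([0.7, 0.3], dtype=np.float32)  # per-token routed weights

# Order used by the kernel: dequantize, add bias, then apply the routed
# weight, all in float32, casting only once at the very end.
out_good = ((acc * a_scale * b_scale + bias) * moe_weight[:, None]).astype(np.float16)

# Wrong order: adding bias before dequantization scales it by a_scale * b_scale
# (a factor of 1/5000 with these made-up values), corrupting its contribution.
out_bad = (((acc + bias) * a_scale * b_scale) * moe_weight[:, None]).astype(np.float16)

assert not np.allclose(out_good, out_bad)
```

Keeping every step in float32 and casting once at the end mirrors the kernel's stated goal: no intermediate result is squeezed through a low-precision dtype before the routed weight is applied.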
Duplicate type conversion left after refactoring

Low Severity

The code contains two consecutive identical calls to accumulator.to(compute_type) at lines 566 and 568. The comment explicitly says "Cast once at the end", yet the cast is performed twice. This appears to be a refactoring artifact: the new line at 566 was added, but the original line at 568 was not removed. While functionally harmless (the conversion is idempotent), the redundant cast contradicts the comment and may confuse future maintainers.


Contributor Author


The duplicated line has been removed.

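The reviewer's remark that the duplicate conversion was "functionally harmless (the conversion is idempotent)" is easy to confirm. The NumPy sketch below stands in for Triton's `accumulator.to(compute_type)`; it is an analogy, not the kernel's actual API:

```python
import numpy as np

acc = np.array([0.1, 1.5, -2.25], dtype=np.float32)

once = acc.astype(np.float16)                      # intended single cast
twice = acc.astype(np.float16).astype(np.float16)  # redundant duplicate cast

# Re-casting to the same dtype is a no-op, so both results are identical;
# the duplicated line was dead weight, not a correctness bug.
assert twice.dtype == once.dtype == np.float16
assert np.array_equal(once, twice)
```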

# -----------------------------------------------------------