17 changes: 15 additions & 2 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -864,13 +864,26 @@ def w8a8_triton_block_scaled_mm(
         # Default config
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
         # BLOCK_SIZE_K must be divisible by block_size[1]
+        # M-aware tuning for low-M decode: the previous default of
+        # BLOCK_SIZE_M=64 wastes 98% of the M-dim for single-request decode
+        # (M=1) and short MTP-style draft batches. Specialize only the M<=8
+        # case to keep blast radius small; larger M keeps the previous default.
+        # num_stages=3 is gated to non-ROCm because MI300/MI250X LDS (64 KB)
+        # is borderline for 3-stage Triton pipelining at typical [128,128]
+        # block sizes; on ROCm we keep num_stages=2 so the M<=8 branch still
+        # gets the BLOCK_SIZE_M=16 wave-quantisation win without LDS pressure.
+        if M <= 8:
+            block_m = 16
+            num_stages = 2 if current_platform.is_rocm() else 3
+        else:
+            block_m, num_stages = 64, 2
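The "wastes 98% of the M-dim" claim in the added comment can be checked with a small standalone sketch. The helper name below is hypothetical and not part of the vLLM code; it only computes the fraction of rows in the launched M-blocks that do real work:

```python
import math

def m_dim_utilization(m: int, block_size_m: int) -> float:
    """Fraction of rows across all launched M-blocks that carry real work."""
    num_blocks = math.ceil(m / block_size_m)
    return m / (num_blocks * block_size_m)

# Single-request decode (M=1) with the old default BLOCK_SIZE_M=64:
# only 1 of 64 rows in the block is useful, so ~98.4% is wasted.
print(f"{1 - m_dim_utilization(1, 64):.1%}")  # 98.4%
# With the specialized BLOCK_SIZE_M=16, waste drops to 93.8% for M=1
# and to 50.0% at the branch boundary M=8.
print(f"{1 - m_dim_utilization(1, 16):.1%}")  # 93.8%
print(f"{1 - m_dim_utilization(8, 16):.1%}")  # 50.0%
```

This matches the 98% figure in the comment and shows why the specialization is limited to M<=8: beyond that, the 64-row block is already reasonably filled.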
Comment on lines +875 to +879

Contributor (severity: high)

Setting num_stages=3 can lead to OutOfResources errors or performance regressions on ROCm, as seen in other Triton kernels in the codebase (e.g., in vllm/model_executor/layers/fused_moe/fused_moe.py). It is safer to use num_stages=2 on ROCm while keeping the optimized value for NVIDIA architectures.
Suggested change
-        if M <= 8:
-            block_m, num_stages = 16, 3
-        else:
-            block_m, num_stages = 64, 2
+        if M <= 8:
+            block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3)
+        else:
+            block_m, num_stages = 64, 2

         config = {
-            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_M": block_m,
             "BLOCK_SIZE_N": block_size[0],
             "BLOCK_SIZE_K": block_size[1],
             "GROUP_SIZE_M": 32,
             "num_warps": 4,
-            "num_stages": 2,
+            "num_stages": num_stages,
         }
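The post-patch selection logic can be sketched as a self-contained function. This is an illustration, not the vLLM implementation: the real code reads `current_platform.is_rocm()`, while here the ROCm check is a plain parameter:

```python
# Sketch of the M-aware config selection after this change. `is_rocm` is a
# stand-in for vllm's current_platform.is_rocm() (an assumption for
# illustration); block_size is the [N, K] quantization block shape.
def pick_config(M: int, block_size: list[int], is_rocm: bool) -> dict:
    if M <= 8:
        # Low-M decode: shrink the M-block, and pipeline deeper off-ROCm.
        block_m = 16
        num_stages = 2 if is_rocm else 3
    else:
        # Larger M keeps the previous default.
        block_m, num_stages = 64, 2
    return {
        "BLOCK_SIZE_M": block_m,
        "BLOCK_SIZE_N": block_size[0],
        "BLOCK_SIZE_K": block_size[1],
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": num_stages,
    }

print(pick_config(1, [128, 128], is_rocm=False))   # BLOCK_SIZE_M=16, num_stages=3
print(pick_config(1, [128, 128], is_rocm=True))    # BLOCK_SIZE_M=16, num_stages=2
print(pick_config(64, [128, 128], is_rocm=False))  # BLOCK_SIZE_M=64, num_stages=2
```

Only BLOCK_SIZE_M and num_stages vary with M and platform; the N/K block sizes stay pinned to the quantization block shape, which the divisibility comments above require.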

def grid(META):