diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 9613b11d35e2..65cd9abb0060 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -864,13 +864,26 @@ def w8a8_triton_block_scaled_mm(
         # Default config
         # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
         # BLOCK_SIZE_K must be divisible by block_size[1]
+        # M-aware tuning for low-M decode: the previous default of
+        # BLOCK_SIZE_M=64 wastes 98% of the M-dim for single-request decode
+        # (M=1) and short MTP-style draft batches. Specialize only the M<=8
+        # case to keep blast radius small; larger M keeps the previous default.
+        # num_stages=3 is gated to non-ROCm because MI300/MI250X LDS (64 KB)
+        # is borderline for 3-stage Triton pipelining at typical [128,128]
+        # block sizes; on ROCm we keep num_stages=2 so the M<=8 branch still
+        # gets the BLOCK_SIZE_M=16 wave-quantisation win without LDS pressure.
+        if M <= 8:
+            block_m = 16
+            num_stages = 2 if current_platform.is_rocm() else 3
+        else:
+            block_m, num_stages = 64, 2
         config = {
-            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_M": block_m,
             "BLOCK_SIZE_N": block_size[0],
             "BLOCK_SIZE_K": block_size[1],
             "GROUP_SIZE_M": 32,
             "num_warps": 4,
-            "num_stages": 2,
+            "num_stages": num_stages,
         }
 
     def grid(META):
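
For context on the "wastes 98% of the M-dim" figure cited in the new comment block, here is a back-of-the-envelope sketch. It is plain-Python arithmetic only, not code from vLLM or from this patch, and m_dim_waste is a hypothetical helper name introduced just for illustration:

import math

def m_dim_waste(m: int, block_m: int) -> float:
    """Fraction of rows in the padded M tiles that carry no real work."""
    padded_rows = math.ceil(m / block_m) * block_m
    return (padded_rows - m) / padded_rows

for block_m in (64, 16):
    print(f"M=1, BLOCK_SIZE_M={block_m}: "
          f"{m_dim_waste(1, block_m):.1%} of the M tile is padding")

# Output:
#   M=1, BLOCK_SIZE_M=64: 98.4% of the M tile is padding
#   M=1, BLOCK_SIZE_M=16: 93.8% of the M tile is padding
#
# The relative padding is still high at BLOCK_SIZE_M=16, but since the
# grid size is ceil(M / BLOCK_SIZE_M) * ceil(N / BLOCK_SIZE_N), M=1
# launches the same number of CTAs either way; each CTA now computes a
# 16xN tile instead of a 64xN one, i.e. a quarter of the M-dim MMA work.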