From f3b8abed5d4429ae8adebb207bfa1d50c64cd99b Mon Sep 17 00:00:00 2001 From: c0de128 Date: Mon, 22 Dec 2025 14:20:11 -0600 Subject: [PATCH] [Bugfix][ROCm] Use platform fp8_dtype in silu_mul_per_token_group_quant_fp8_colmajor The function was hardcoding torch.float8_e4m3fn dtype and using its default min/max values. On ROCm platforms that use torch.float8_e4m3fnuz, this causes incorrect dtype and accuracy issues. This fix: - Uses current_platform.fp8_dtype() instead of hardcoded dtype - Applies the same ROCm-aware fp8 min/max logic (224.0 for fnuz) that is already used in per_token_group_quant_fp8() in the same file Signed-off-by: c0de128 --- .../layers/quantization/utils/fp8_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 15ea9f7d60ff..8e4dde324f39 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor( M, N = input.size() N_2 = N // 2 + fp8_dtype = current_platform.fp8_dtype() if output is None: - output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device) + output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device) output_scales = torch.empty( ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device @@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor( assert M % BLOCK_M == 0 assert N_2 % BLOCK_N == 0 - finfo = torch.finfo(torch.float8_e4m3fn) - fp8_min = finfo.min - fp8_max = finfo.max + # Using the default value (240.0) from pytorch will cause accuracy + # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm + # platforms that use the torch.float8_e4m3fnuz dtype. + finfo = torch.finfo(fp8_dtype) + fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min + fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max # Force even division so we can avoid edgecases within the kernel. assert M % BLOCK_M == 0