diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 15ea9f7d60ff..8e4dde324f39 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
     M, N = input.size()
     N_2 = N // 2
 
+    fp8_dtype = current_platform.fp8_dtype()
     if output is None:
-        output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
+        output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
 
     output_scales = torch.empty(
         ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
     assert M % BLOCK_M == 0
     assert N_2 % BLOCK_N == 0
 
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    fp8_min = finfo.min
-    fp8_max = finfo.max
+    # Using the default value (240.0) from PyTorch causes accuracy issues on
+    # dynamically quantized models, so use 224.0 on ROCm platforms that use
+    # the torch.float8_e4m3fnuz (fnuz) dtype.
+    finfo = torch.finfo(fp8_dtype)
+    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
+    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
 
     # Force even division so we can avoid edgecases within the kernel.
     assert M % BLOCK_M == 0
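
For reference, a minimal standalone sketch (not the vLLM helper itself) of the range selection this patch introduces. Here is_fp8_fnuz stands in for current_platform.is_fp8_fnuz(), and the dtype limits come from torch.finfo; the 224.0 clamp mirrors the patched code above.

# Illustrative sketch only: mirrors the fp8_min/fp8_max selection in the patch.
import torch

def fp8_quant_range(is_fp8_fnuz: bool) -> tuple[float, float]:
    # ROCm's torch.float8_e4m3fnuz reports a finfo max of 240.0, which the
    # patch avoids by clamping to +/-224.0; the OCP torch.float8_e4m3fn
    # dtype keeps its native finfo range of +/-448.0.
    fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz else torch.float8_e4m3fn
    finfo = torch.finfo(fp8_dtype)
    if is_fp8_fnuz:
        return -224.0, 224.0
    return finfo.min, finfo.max

print(fp8_quant_range(False))  # (-448.0, 448.0) on an e4m3fn platform
print(fp8_quant_range(True))   # (-224.0, 224.0) on a fnuz (ROCm) platform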