From abb9dfda4c28bc6233ce5ad97572b9d58cbd050e Mon Sep 17 00:00:00 2001 From: c0de128 Date: Sun, 21 Dec 2025 20:13:16 -0600 Subject: [PATCH 1/7] [Bugfix][Hardware][AMD] Consolidate FP8 min/max values into helper function Add get_fp8_min_max() helper in quant_utils.py to centralize the FP8 min/max value logic for ROCm fnuz dtype handling. On ROCm with torch.float8_e4m3fnuz, using PyTorch's default finfo.max (240.0) causes accuracy issues with dynamic quantization. The correct value is 224.0 for fnuz dtype. This change: - Adds get_fp8_min_max(dtype) helper returning (fp8_min, fp8_max) tuple - Updates input_quant_fp8.py to use the helper - Updates fp8_utils.py per_token_group_quant_fp8() to use the helper - Updates deep_gemm.py per_block_cast_to_fp8() to use the helper - Updates tests/kernels/quant_utils.py to use the helper Fixes #30360 Signed-off-by: c0de128 --- tests/kernels/quant_utils.py | 36 ++++++------------- .../layers/quantization/input_quant_fp8.py | 11 +++--- .../layers/quantization/utils/fp8_utils.py | 12 +++---- .../layers/quantization/utils/quant_utils.py | 22 ++++++++++++ vllm/utils/deep_gemm.py | 6 +++- 5 files changed, 48 insertions(+), 39 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 7927bd0d200d..479338b990a2 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -4,13 +4,13 @@ import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + group_broadcast, +) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up -# Using the default value (240.0) from pytorch will cause accuracy -# issue on dynamic quantization models. Here use 224.0 for rocm. -ROCM_FP8FNUZ_MAX = 224.0 FP8_DTYPE = current_platform.fp8_dtype() @@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant( if scale_ub is not None: assert quant_dtype == FP8_DTYPE - qtype_traits = ( - torch.iinfo(quant_dtype) - if quant_dtype == torch.int8 - else torch.finfo(quant_dtype) - ) - use_fp8fnuz = ( - current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype() - ) - qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max - qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min + if quant_dtype == torch.int8: + qtype_traits = torch.iinfo(quant_dtype) + qtype_traits_min = qtype_traits.min + qtype_traits_max = qtype_traits.max + else: + qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype) qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant( def ref_dynamic_per_tensor_fp8_quant( x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - fp8_traits = torch.finfo(FP8_DTYPE) - fp8_traits_max = ( - ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else fp8_traits.max - ) - fp8_traits_min = ( - -ROCM_FP8FNUZ_MAX - if current_platform.is_rocm() and current_platform.is_fp8_fnuz() - else fp8_traits.min - ) + fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE) fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 7994c838ad54..f57deceaf6ca 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -7,15 +7,14 @@ from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + get_fp8_min_max, +) from vllm.platforms import current_platform -# Using the default value (240.0) from pytorch will cause accuracy -# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. _FP8_DTYPE = current_platform.fp8_dtype() -_FP8_FINFO = torch.finfo(_FP8_DTYPE) -_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max -_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE) _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index de6a1e8c1aa7..55df073be444 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -15,7 +15,10 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 -from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + get_fp8_min_max, +) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, ) @@ -748,12 +751,7 @@ def per_token_group_quant_fp8( ) assert x.stride(-1) == 1, "`x` groups must be contiguous" - # Using the default value (240.0) from pytorch will cause accuracy - # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm - # platforms that use the torch.float8_e4mefnuz dtype. - finfo = torch.finfo(dtype) - fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min - fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max + fp8_min, fp8_max = get_fp8_min_max(dtype) assert out_q is None or out_q.shape == x.shape x_q = out_q diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d01263f82007..9829972cc365 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -19,6 +19,28 @@ FP4_DTYPE = torch.uint8 +def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]: + """ + Get the min and max values for FP8 quantization. + + On ROCm with torch.float8_e4m3fnuz (fnuz), the default PyTorch finfo.max + (240.0) causes accuracy issues with dynamic quantization models. + Use 224.0 instead for fnuz dtype. + + Args: + dtype: FP8 dtype (defaults to platform's FP8 dtype if None) + + Returns: + Tuple of (fp8_min, fp8_max) values + """ + if dtype is None: + dtype = FP8_DTYPE + finfo = torch.finfo(dtype) + if current_platform.is_fp8_fnuz(): + return -224.0, 224.0 + return finfo.min, finfo.max + + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): row: int diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 56c9ca361eae..e66444228203 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -16,6 +16,9 @@ import vllm.envs as envs from vllm.logger import logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, +) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_gemm from vllm.utils.math_utils import cdiv @@ -355,7 +358,8 @@ def per_block_cast_to_fp8( x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0 + _, fp8_max = get_fp8_min_max(fp8_dtype) + sf = x_amax / fp8_max sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( From e2e6438167e737dd8408fdfdd0190ecfb270295a Mon Sep 17 00:00:00 2001 From: c0de128 Date: Sun, 21 Dec 2025 21:06:01 -0600 Subject: [PATCH 2/7] Fix dtype check in get_fp8_min_max helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address review feedback: Only apply the 224.0 override when both: 1. Platform supports fnuz (is_fp8_fnuz()) 2. The dtype is actually torch.float8_e4m3fnuz This prevents incorrect min/max values when a non-fnuz dtype is explicitly passed on a platform that supports fnuz. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: c0de128 --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 9829972cc365..d2eb768ede9a 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -36,7 +36,8 @@ def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]: if dtype is None: dtype = FP8_DTYPE finfo = torch.finfo(dtype) - if current_platform.is_fp8_fnuz(): + # Only apply the 224.0 override for the actual fnuz dtype on fnuz platform + if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz: return -224.0, 224.0 return finfo.min, finfo.max From fb90db603d4c1aa9acdc3f51bdaf1a978df134d3 Mon Sep 17 00:00:00 2001 From: c0de128 Date: Tue, 23 Dec 2025 09:37:41 -0600 Subject: [PATCH 3/7] Add unit tests for get_fp8_min_max() helper function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test_fp8_min_max_helper.py with mocked unit tests that verify: - Standard FP8 dtype uses PyTorch's finfo values - fnuz dtype on fnuz platform (MI300) returns 224.0, not 240.0 - Standard dtype on fnuz platform uses finfo values - fnuz dtype on non-fnuz platform uses finfo values These tests use mocking to verify the logic without requiring actual ROCm hardware. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: c0de128 --- .../quantization/test_fp8_min_max_helper.py | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) create mode 100644 tests/kernels/quantization/test_fp8_min_max_helper.py diff --git a/tests/kernels/quantization/test_fp8_min_max_helper.py b/tests/kernels/quantization/test_fp8_min_max_helper.py new file mode 100644 index 000000000000..b19637b3adef --- /dev/null +++ b/tests/kernels/quantization/test_fp8_min_max_helper.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for the get_fp8_min_max() helper function. + +These tests verify the FP8 min/max value logic for both standard +and fnuz (ROCm MI300) dtype handling. +""" + +from unittest.mock import patch + +import pytest +import torch + + +class TestGetFp8MinMax: + """Test cases for get_fp8_min_max() function.""" + + def test_standard_fp8_dtype(self): + """Test that standard FP8 dtype uses PyTorch's finfo values.""" + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # For standard float8_e4m3fn, should return finfo values + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) + finfo = torch.finfo(torch.float8_e4m3fn) + + # Standard FP8 max is 448.0 for e4m3fn + assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}" + assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}" + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform): + """Test that fnuz dtype on fnuz platform returns 224.0.""" + mock_platform.is_fp8_fnuz.return_value = True + mock_platform.fp8_dtype.return_value = torch.float8_e4m3fnuz + + # Re-import to use mocked platform + from importlib import reload + + import vllm.model_executor.layers.quantization.utils.quant_utils as qu + + reload(qu) + + fp8_min, fp8_max = qu.get_fp8_min_max(torch.float8_e4m3fnuz) + + # fnuz on ROCm MI300 should return 224.0, not 240.0 + assert fp8_max == 224.0, ( + f"Expected 224.0 for fnuz on fnuz platform, got {fp8_max}" + ) + assert fp8_min == -224.0, ( + f"Expected -224.0 for fnuz on fnuz platform, got {fp8_min}" + ) + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_standard_dtype_on_fnuz_platform(self, mock_platform): + """Test that standard dtype on fnuz platform uses finfo values.""" + mock_platform.is_fp8_fnuz.return_value = True + + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # Standard e4m3fn dtype should use finfo even on fnuz platform + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) + finfo = torch.finfo(torch.float8_e4m3fn) + + assert fp8_max == finfo.max, ( + f"Standard dtype should use finfo.max={finfo.max}, got {fp8_max}" + ) + + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform): + """Test that fnuz dtype on non-fnuz platform uses finfo values.""" + mock_platform.is_fp8_fnuz.return_value = False + + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, + ) + + # fnuz dtype on non-fnuz platform should use finfo + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) + finfo = torch.finfo(torch.float8_e4m3fnuz) + + # Should be 240.0, not 224.0 (non-fnuz platform) + assert fp8_max == finfo.max, ( + f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 380b9ac609e6e9b7d7ddc14664d7a89f874349dd Mon Sep 17 00:00:00 2001 From: c0de128 Date: Mon, 5 Jan 2026 09:11:28 -0600 Subject: [PATCH 4/7] Fix test by removing problematic reload() calls Remove reload() usage which can cause module state issues and test isolation problems. Instead, import the function once at module level and let the @patch decorator handle mocking correctly. Signed-off-by: c0de128 --- .../quantization/test_fp8_min_max_helper.py | 26 ++++--------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/tests/kernels/quantization/test_fp8_min_max_helper.py b/tests/kernels/quantization/test_fp8_min_max_helper.py index b19637b3adef..af468b892629 100644 --- a/tests/kernels/quantization/test_fp8_min_max_helper.py +++ b/tests/kernels/quantization/test_fp8_min_max_helper.py @@ -12,16 +12,16 @@ import pytest import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_fp8_min_max, +) + class TestGetFp8MinMax: """Test cases for get_fp8_min_max() function.""" def test_standard_fp8_dtype(self): """Test that standard FP8 dtype uses PyTorch's finfo values.""" - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - get_fp8_min_max, - ) - # For standard float8_e4m3fn, should return finfo values fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn) @@ -34,16 +34,8 @@ def test_standard_fp8_dtype(self): def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform): """Test that fnuz dtype on fnuz platform returns 224.0.""" mock_platform.is_fp8_fnuz.return_value = True - mock_platform.fp8_dtype.return_value = torch.float8_e4m3fnuz - - # Re-import to use mocked platform - from importlib import reload - - import vllm.model_executor.layers.quantization.utils.quant_utils as qu - - reload(qu) - fp8_min, fp8_max = qu.get_fp8_min_max(torch.float8_e4m3fnuz) + fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) # fnuz on ROCm MI300 should return 224.0, not 240.0 assert fp8_max == 224.0, ( @@ -58,10 +50,6 @@ def test_standard_dtype_on_fnuz_platform(self, mock_platform): """Test that standard dtype on fnuz platform uses finfo values.""" mock_platform.is_fp8_fnuz.return_value = True - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - get_fp8_min_max, - ) - # Standard e4m3fn dtype should use finfo even on fnuz platform fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn) @@ -75,10 +63,6 @@ def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform): """Test that fnuz dtype on non-fnuz platform uses finfo values.""" mock_platform.is_fp8_fnuz.return_value = False - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - get_fp8_min_max, - ) - # fnuz dtype on non-fnuz platform should use finfo fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) finfo = torch.finfo(torch.float8_e4m3fnuz) From f4683c4234ba10089a5b266bef0ce472f8fcd31b Mon Sep 17 00:00:00 2001 From: c0de128 Date: Mon, 5 Jan 2026 17:41:04 -0600 Subject: [PATCH 5/7] Simplify get_fp8_min_max() per code review feedback Address @rasmith's suggestions: - Remove dtype parameter, use current_platform.fp8_dtype() internally - Simplify logic: if is_fp8_fnuz() return -224,224 else use finfo - Update all call sites to use parameter-less function - Simplify tests to mock platform instead of passing dtype Signed-off-by: Kevin McKay Signed-off-by: c0de128 --- tests/kernels/quant_utils.py | 4 +- .../quantization/test_fp8_min_max_helper.py | 52 +++++++------------ .../layers/quantization/input_quant_fp8.py | 2 +- .../layers/quantization/utils/fp8_utils.py | 2 +- .../layers/quantization/utils/quant_utils.py | 16 ++---- vllm/utils/deep_gemm.py | 2 +- 6 files changed, 30 insertions(+), 48 deletions(-) diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index 479338b990a2..3d11413c5ad8 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -30,7 +30,7 @@ def ref_dynamic_per_token_quant( qtype_traits_min = qtype_traits.min qtype_traits_max = qtype_traits.max else: - qtype_traits_min, qtype_traits_max = get_fp8_min_max(quant_dtype) + qtype_traits_min, qtype_traits_max = get_fp8_min_max() qtype_max = as_float32_tensor(qtype_traits_max) s_1 = as_float32_tensor(1.0) s_512 = as_float32_tensor(512.0) @@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant( def ref_dynamic_per_tensor_fp8_quant( x: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - fp8_traits_min, fp8_traits_max = get_fp8_min_max(FP8_DTYPE) + fp8_traits_min, fp8_traits_max = get_fp8_min_max() fp8_max = as_float32_tensor(fp8_traits_max) one = as_float32_tensor(1.0) diff --git a/tests/kernels/quantization/test_fp8_min_max_helper.py b/tests/kernels/quantization/test_fp8_min_max_helper.py index af468b892629..8cd68a3fef7e 100644 --- a/tests/kernels/quantization/test_fp8_min_max_helper.py +++ b/tests/kernels/quantization/test_fp8_min_max_helper.py @@ -20,10 +20,13 @@ class TestGetFp8MinMax: """Test cases for get_fp8_min_max() function.""" - def test_standard_fp8_dtype(self): - """Test that standard FP8 dtype uses PyTorch's finfo values.""" - # For standard float8_e4m3fn, should return finfo values - fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) + @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") + def test_standard_fp8_platform(self, mock_platform): + """Test that standard FP8 platform uses PyTorch's finfo values.""" + mock_platform.is_fp8_fnuz.return_value = False + mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn + + fp8_min, fp8_max = get_fp8_min_max() finfo = torch.finfo(torch.float8_e4m3fn) # Standard FP8 max is 448.0 for e4m3fn @@ -31,46 +34,31 @@ def test_standard_fp8_dtype(self): assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}" @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") - def test_fnuz_fp8_dtype_on_fnuz_platform(self, mock_platform): - """Test that fnuz dtype on fnuz platform returns 224.0.""" + def test_fnuz_platform_returns_224(self, mock_platform): + """Test that fnuz platform returns 224.0.""" mock_platform.is_fp8_fnuz.return_value = True - fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) + fp8_min, fp8_max = get_fp8_min_max() # fnuz on ROCm MI300 should return 224.0, not 240.0 - assert fp8_max == 224.0, ( - f"Expected 224.0 for fnuz on fnuz platform, got {fp8_max}" - ) - assert fp8_min == -224.0, ( - f"Expected -224.0 for fnuz on fnuz platform, got {fp8_min}" - ) + assert fp8_max == 224.0, f"Expected 224.0 for fnuz platform, got {fp8_max}" + assert fp8_min == -224.0, f"Expected -224.0 for fnuz platform, got {fp8_min}" @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") - def test_standard_dtype_on_fnuz_platform(self, mock_platform): - """Test that standard dtype on fnuz platform uses finfo values.""" - mock_platform.is_fp8_fnuz.return_value = True - - # Standard e4m3fn dtype should use finfo even on fnuz platform - fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fn) - finfo = torch.finfo(torch.float8_e4m3fn) - - assert fp8_max == finfo.max, ( - f"Standard dtype should use finfo.max={finfo.max}, got {fp8_max}" - ) - - @patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform") - def test_fnuz_dtype_on_non_fnuz_platform(self, mock_platform): - """Test that fnuz dtype on non-fnuz platform uses finfo values.""" + def test_non_fnuz_platform_uses_finfo(self, mock_platform): + """Test that non-fnuz platform uses finfo values.""" mock_platform.is_fp8_fnuz.return_value = False + mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn - # fnuz dtype on non-fnuz platform should use finfo - fp8_min, fp8_max = get_fp8_min_max(torch.float8_e4m3fnuz) - finfo = torch.finfo(torch.float8_e4m3fnuz) + fp8_min, fp8_max = get_fp8_min_max() + finfo = torch.finfo(torch.float8_e4m3fn) - # Should be 240.0, not 224.0 (non-fnuz platform) assert fp8_max == finfo.max, ( f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}" ) + assert fp8_min == finfo.min, ( + f"Non-fnuz platform should use finfo.min={finfo.min}, got {fp8_min}" + ) if __name__ == "__main__": diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index f57deceaf6ca..36508dff6577 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -14,7 +14,7 @@ from vllm.platforms import current_platform _FP8_DTYPE = current_platform.fp8_dtype() -_FP8_MIN, _FP8_MAX = get_fp8_min_max(_FP8_DTYPE) +_FP8_MIN, _FP8_MAX = get_fp8_min_max() _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 55df073be444..880b6a89ced3 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -751,7 +751,7 @@ def per_token_group_quant_fp8( ) assert x.stride(-1) == 1, "`x` groups must be contiguous" - fp8_min, fp8_max = get_fp8_min_max(dtype) + fp8_min, fp8_max = get_fp8_min_max() assert out_q is None or out_q.shape == x.shape x_q = out_q diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d2eb768ede9a..5956bf984d9f 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -19,26 +19,20 @@ FP4_DTYPE = torch.uint8 -def get_fp8_min_max(dtype: torch.dtype | None = None) -> tuple[float, float]: +def get_fp8_min_max() -> tuple[float, float]: """ Get the min and max values for FP8 quantization. - On ROCm with torch.float8_e4m3fnuz (fnuz), the default PyTorch finfo.max - (240.0) causes accuracy issues with dynamic quantization models. + On ROCm platforms that use the torch.float8_e4m3fnuz dtype, the default + PyTorch finfo.max (240.0) causes accuracy issues with dynamic quantization. Use 224.0 instead for fnuz dtype. - Args: - dtype: FP8 dtype (defaults to platform's FP8 dtype if None) - Returns: Tuple of (fp8_min, fp8_max) values """ - if dtype is None: - dtype = FP8_DTYPE - finfo = torch.finfo(dtype) - # Only apply the 224.0 override for the actual fnuz dtype on fnuz platform - if current_platform.is_fp8_fnuz() and dtype == torch.float8_e4m3fnuz: + if current_platform.is_fp8_fnuz(): return -224.0, 224.0 + finfo = torch.finfo(current_platform.fp8_dtype()) return finfo.min, finfo.max diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index e66444228203..84e0fbb449fb 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -358,7 +358,7 @@ def per_block_cast_to_fp8( x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - _, fp8_max = get_fp8_min_max(fp8_dtype) + _, fp8_max = get_fp8_min_max() sf = x_amax / fp8_max sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype) From f5d01317fb531a253d9c3645749c0b91f2af89d4 Mon Sep 17 00:00:00 2001 From: c0de128 Date: Tue, 6 Jan 2026 07:58:25 -0600 Subject: [PATCH 6/7] Update comment format per reviewer feedback Signed-off-by: c0de128 --- .../layers/quantization/utils/quant_utils.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 5956bf984d9f..c598f8093a96 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,16 +20,10 @@ def get_fp8_min_max() -> tuple[float, float]: - """ - Get the min and max values for FP8 quantization. - - On ROCm platforms that use the torch.float8_e4m3fnuz dtype, the default - PyTorch finfo.max (240.0) causes accuracy issues with dynamic quantization. - Use 224.0 instead for fnuz dtype. - - Returns: - Tuple of (fp8_min, fp8_max) values - """ + """Get the min and max values for FP8 quantization.""" + # Using the default value (240.0) from pytorch will cause accuracy + # issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz on + # ROCm platforms that use the torch.float8_e4m3fnuz dtype. if current_platform.is_fp8_fnuz(): return -224.0, 224.0 finfo = torch.finfo(current_platform.fp8_dtype()) From c72dd1156c8c12ec8880e9212ffe732c55d56227 Mon Sep 17 00:00:00 2001 From: c0de128 Date: Tue, 6 Jan 2026 09:57:14 -0600 Subject: [PATCH 7/7] Adjust comment line wrapping Signed-off-by: c0de128 --- vllm/model_executor/layers/quantization/utils/quant_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c598f8093a96..679b448842cd 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -22,8 +22,8 @@ def get_fp8_min_max() -> tuple[float, float]: """Get the min and max values for FP8 quantization.""" # Using the default value (240.0) from pytorch will cause accuracy - # issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz on - # ROCm platforms that use the torch.float8_e4m3fnuz dtype. + # issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz + # on ROCm platforms that use the torch.float8_e4m3fnuz dtype. if current_platform.is_fp8_fnuz(): return -224.0, 224.0 finfo = torch.finfo(current_platform.fp8_dtype())