Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 11 additions & 25 deletions tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import group_broadcast
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
group_broadcast,
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8FNUZ_MAX = 224.0
FP8_DTYPE = current_platform.fp8_dtype()


Expand All @@ -25,16 +25,12 @@ def ref_dynamic_per_token_quant(
if scale_ub is not None:
assert quant_dtype == FP8_DTYPE

qtype_traits = (
torch.iinfo(quant_dtype)
if quant_dtype == torch.int8
else torch.finfo(quant_dtype)
)
use_fp8fnuz = (
current_platform.is_fp8_fnuz() and quant_dtype == current_platform.fp8_dtype()
)
qtype_traits_max = ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.max
qtype_traits_min = -ROCM_FP8FNUZ_MAX if use_fp8fnuz else qtype_traits.min
if quant_dtype == torch.int8:
qtype_traits = torch.iinfo(quant_dtype)
qtype_traits_min = qtype_traits.min
qtype_traits_max = qtype_traits.max
else:
qtype_traits_min, qtype_traits_max = get_fp8_min_max()
qtype_max = as_float32_tensor(qtype_traits_max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)
Expand Down Expand Up @@ -72,17 +68,7 @@ def ref_dynamic_per_token_quant(
def ref_dynamic_per_tensor_fp8_quant(
x: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_traits = torch.finfo(FP8_DTYPE)
fp8_traits_max = (
ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.max
)
fp8_traits_min = (
-ROCM_FP8FNUZ_MAX
if current_platform.is_rocm() and current_platform.is_fp8_fnuz()
else fp8_traits.min
)
fp8_traits_min, fp8_traits_max = get_fp8_min_max()
fp8_max = as_float32_tensor(fp8_traits_max)
one = as_float32_tensor(1.0)

Expand Down
65 changes: 65 additions & 0 deletions tests/kernels/quantization/test_fp8_min_max_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for the get_fp8_min_max() helper function.

These tests verify the FP8 min/max value logic for both standard
and fnuz (ROCm MI300) dtype handling.
"""

from unittest.mock import patch

import pytest
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)


class TestGetFp8MinMax:
"""Test cases for get_fp8_min_max() function."""

@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_standard_fp8_platform(self, mock_platform):
"""Test that standard FP8 platform uses PyTorch's finfo values."""
mock_platform.is_fp8_fnuz.return_value = False
mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn

fp8_min, fp8_max = get_fp8_min_max()
finfo = torch.finfo(torch.float8_e4m3fn)

# Standard FP8 max is 448.0 for e4m3fn
assert fp8_max == finfo.max, f"Expected finfo.max={finfo.max}, got {fp8_max}"
assert fp8_min == finfo.min, f"Expected finfo.min={finfo.min}, got {fp8_min}"

@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_fnuz_platform_returns_224(self, mock_platform):
"""Test that fnuz platform returns 224.0."""
mock_platform.is_fp8_fnuz.return_value = True

fp8_min, fp8_max = get_fp8_min_max()

# fnuz on ROCm MI300 should return 224.0, not 240.0
assert fp8_max == 224.0, f"Expected 224.0 for fnuz platform, got {fp8_max}"
assert fp8_min == -224.0, f"Expected -224.0 for fnuz platform, got {fp8_min}"

@patch("vllm.model_executor.layers.quantization.utils.quant_utils.current_platform")
def test_non_fnuz_platform_uses_finfo(self, mock_platform):
"""Test that non-fnuz platform uses finfo values."""
mock_platform.is_fp8_fnuz.return_value = False
mock_platform.fp8_dtype.return_value = torch.float8_e4m3fn

fp8_min, fp8_max = get_fp8_min_max()
finfo = torch.finfo(torch.float8_e4m3fn)

assert fp8_max == finfo.max, (
f"Non-fnuz platform should use finfo.max={finfo.max}, got {fp8_max}"
)
assert fp8_min == finfo.min, (
f"Non-fnuz platform should use finfo.min={finfo.min}, got {fp8_min}"
)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
11 changes: 5 additions & 6 deletions vllm/model_executor/layers/quantization/input_quant_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.platforms import current_platform

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm.
_FP8_DTYPE = current_platform.fp8_dtype()
_FP8_FINFO = torch.finfo(_FP8_DTYPE)
_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max
_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min
_FP8_MIN, _FP8_MAX = get_fp8_min_max()
_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)


Expand Down
12 changes: 5 additions & 7 deletions vllm/model_executor/layers/quantization/utils/fp8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_fp8_min_max,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
)
Expand Down Expand Up @@ -748,12 +751,7 @@ def per_token_group_quant_fp8(
)
assert x.stride(-1) == 1, "`x` groups must be contiguous"

# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4mefnuz dtype.
finfo = torch.finfo(dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
fp8_min, fp8_max = get_fp8_min_max()

assert out_q is None or out_q.shape == x.shape
x_q = out_q
Expand Down
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
FP4_DTYPE = torch.uint8


def get_fp8_min_max() -> tuple[float, float]:
"""Get the min and max values for FP8 quantization."""
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz
# on ROCm platforms that use the torch.float8_e4m3fnuz dtype.
if current_platform.is_fp8_fnuz():
return -224.0, 224.0
finfo = torch.finfo(current_platform.fp8_dtype())
return finfo.min, finfo.max


# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
row: int
Expand Down
6 changes: 5 additions & 1 deletion vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

import vllm.envs as envs
from vllm.logger import logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_fp8_min_max,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
Expand Down Expand Up @@ -355,7 +358,8 @@ def per_block_cast_to_fp8(
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
_, fp8_max = get_fp8_min_max()
sf = x_amax / fp8_max
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
Expand Down