Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,7 @@ steps:
# Whisper needs spawn method to avoid deadlock
- VLLM_WORKER_MULTIPROC_METHOD=spawn python3 examples/offline_inference/audio_language.py --model-type whisper

- label: Blackwell Test # 21 min
- label: Blackwell Test # 23 min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/"
gpu: b200
Expand Down Expand Up @@ -991,6 +991,8 @@ steps:
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
- pytest -v -s tests/kernels/moe/test_flashinfer.py
- pytest -v -s tests/kernels/moe/test_cutedsl_moe.py
# e2e
- pytest -v -s tests/models/quantization/test_nvfp4.py

- label: Blackwell Fusion and Compile Tests # 30 min
timeout_in_minutes: 40
Expand Down
26 changes: 24 additions & 2 deletions tests/kernels/quantization/nvfp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,26 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
return out[0:m, 0:k]


def convert_swizzled_8x4_layout_to_linear(
    a_sf_swizzled: torch.Tensor, m, k, block_size
):
    """Undo the 8x4-tiled (swizzled) scale-factor layout into a linear one.

    The swizzled buffer stores scale factors in row tiles of 8 and column
    tiles of 4, where each column entry covers ``block_size`` elements of k.
    This reorders it back to a plain ``(m, k // block_size)``-style matrix.
    """
    row_tile = 8
    cols_per_tile = 4
    k_per_tile = block_size * cols_per_tile
    # Ceiling division via negated floor-division.
    n_row_tiles = -(-m // row_tile)
    n_col_tiles = -(-k // k_per_tile)
    tiled = a_sf_swizzled.reshape(
        1, n_row_tiles, n_col_tiles, row_tile, cols_per_tile
    )
    # Swap the row-within-tile axis ahead of the column-tile axis so that
    # consecutive rows of the output come from consecutive tile rows.
    linear = tiled.permute(0, 1, 3, 2, 4).reshape(
        n_row_tiles * row_tile, n_col_tiles * cols_per_tile
    )
    # NOTE(review): the column slice uses `k` (matching the sibling
    # convert_swizzled_to_linear); since the column count is ~k/block_size,
    # this slice never truncates columns — presumably intentional.
    return linear[:m, :k]


def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16,
is_sf_128x4_layout=True,
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
Expand All @@ -34,7 +52,11 @@ def dequantize_nvfp4_to_dtype(
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
if is_sf_128x4_layout:
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
else:
tensor_sf = convert_swizzled_8x4_layout_to_linear(tensor_sf, m, k, block_size)

tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale

# scale the tensor
Expand Down
34 changes: 26 additions & 8 deletions tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
)
from vllm.utils.torch_utils import set_random_seed

if not current_platform.has_device_capability(100):
Expand All @@ -22,8 +24,14 @@

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES = [
(128, 128, 64),
(128, 128, 128),
(256, 128, 64),
(128, 256, 128),
(1, 128, 128),
]
PAD_SHAPES = [(150, 128, 64), (128, 128, 96), (2, 128, 64), (3, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
Expand All @@ -42,12 +50,19 @@ def get_ref_results(
dtype,
block_size,
device,
is_sf_128x4_layout,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size,
is_sf_128x4_layout=is_sf_128x4_layout,
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
Expand All @@ -70,7 +85,7 @@ def test_flashinfer_nvfp4_gemm(
backend: str,
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
if "trtllm" in backend and dtype == torch.float16:
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")

set_random_seed(seed)
Expand All @@ -87,11 +102,14 @@ def test_flashinfer_nvfp4_gemm(
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)

# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend)
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)

b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)

# get_ref_results unswizzles the scales internally.
Expand All @@ -107,14 +125,14 @@ def test_flashinfer_nvfp4_gemm(
dtype,
block_size,
device,
is_sf_128x4_layout,
)

import flashinfer

if backend == "trtllm":
if "trtllm" in backend:
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)

b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size
)
Expand Down
26 changes: 26 additions & 0 deletions tests/models/quantization/test_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams

from vllm.platforms import current_platform

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MAX_MODEL_LEN = 1024
Expand Down Expand Up @@ -83,3 +85,27 @@ def test_models(example_prompts, model_name) -> None:
assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}"
)


# Run each backend both with and without CUDA-graph capture (eager mode).
EAGER = [True, False]


@pytest.mark.skipif(
    not current_platform.has_device_capability(100),
    reason="modelopt_fp4 is not supported on this GPU type.",
)
@pytest.mark.parametrize("model", ["nvidia/Llama-3.1-8B-Instruct-NVFP4"])
@pytest.mark.parametrize("eager", EAGER)
@pytest.mark.parametrize(
    "backend",
    [
        "flashinfer-cudnn",
        "flashinfer-trtllm",  # the small seq_len ensures trtllm_8x4_layout backend is used
        "flashinfer-cutlass",
    ],
)
def test_nvfp4(vllm_runner, model, eager, backend, monkeypatch):
    """End-to-end smoke test for NVFP4 inference across flashinfer GEMM backends.

    Forces the GEMM backend via VLLM_NVFP4_GEMM_BACKEND, then checks that
    greedy decoding continues the prompt "1 2 3 4 5" with " 6".
    """
    monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", backend)
    with vllm_runner(model, enforce_eager=eager) as llm:
        output = llm.generate_greedy(["1 2 3 4 5"], max_tokens=2)
        # output[0][1] appears to be the full text (prompt + completion)
        # — presumably the vllm_runner generate_greedy contract; confirm
        # against tests/conftest.py if this assertion ever flakes.
        assert output[0][1] == "1 2 3 4 5 6"
46 changes: 30 additions & 16 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
from vllm.utils.flashinfer import (
flashinfer_quant_nvfp4_8x4_sf_layout,
)

logger = init_logger(__name__)

Expand Down Expand Up @@ -1563,7 +1566,9 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:

# fp4
def scaled_fp4_quant(
input: torch.Tensor, input_global_scale: torch.Tensor
input: torch.Tensor,
input_global_scale: torch.Tensor,
backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
Expand All @@ -1577,6 +1582,7 @@ def scaled_fp4_quant(
Args:
input: The input tensor to be quantized to FP4
input_global_scale: A scalar scaling factor for the entire tensor.
        backend: GEMM backend name. When it contains "trtllm" and the input
            has at most 32 rows, the 8x4 scaling-factor layout is used
            instead of the default 128x4 layout.

Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
Expand All @@ -1596,23 +1602,31 @@ def scaled_fp4_quant(
f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."
)

# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False # noqa: SIM210

# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
if use_8x4_sf_layout:
output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout(
input, input_global_scale
)
else:
# Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)

torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)

torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale

Expand Down
7 changes: 6 additions & 1 deletion vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,12 @@ def get_vllm_port() -> int | None:
"VLLM_NVFP4_GEMM_BACKEND": env_with_choices(
"VLLM_NVFP4_GEMM_BACKEND",
None,
["flashinfer-cudnn", "flashinfer-trtllm", "flashinfer-cutlass", "cutlass"],
[
"flashinfer-cudnn",
"flashinfer-trtllm",
"flashinfer-cutlass",
"cutlass",
],
),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
)

logger = init_logger(__name__)

Expand Down Expand Up @@ -187,7 +190,9 @@ def apply_weights(
output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)
x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_global_scale, self.backend
)

mm_args = (
x_fp4,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ def apply(
output_shape = [x.shape[0], layer.weight.shape[0]]

# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend)

# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
Expand Down
Loading