benchmarks/kernels/benchmark_quantfp8_group.py (new file: 148 additions, 0 deletions)
@@ -0,0 +1,148 @@
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark for QuantFP8 Group Quantization implementation."""

import argparse

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


def _time_cuda(
fn,
warmup_iters: int,
bench_iters: int,
) -> float:
# warmup
for _ in range(warmup_iters):
fn()
torch.cuda.synchronize()

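    # CUDA events record timestamps on the GPU stream; the synchronize after
    # the loop ensures all iterations have finished before reading them.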
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
for _ in range(bench_iters):
fn()
end.record()
torch.cuda.synchronize()

return start.elapsed_time(end) / bench_iters # ms/iter


def run_benchmark(
shape: tuple[int, int],
group_size: int,
column_major: bool,
warmup_iters: int,
bench_iters: int,
) -> None:
"""Benchmark QuantFP8 with group quantization using different backends."""
num_tokens, hidden_dim = shape

device = torch.device("cuda")
torch.manual_seed(42)
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(
static=False, group_shape=group_shape, column_major_scales=column_major
)

def cuda_impl():
return quant_op.forward_cuda(x.clone())

def native_impl():
return quant_op.forward_native(x.clone())

cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
native_ms = _time_cuda(native_impl, warmup_iters, bench_iters)

    # Speedup of the CUDA kernel over the PyTorch-native path (>1 means the
    # CUDA kernel is faster).
    speedup = native_ms / cuda_ms if cuda_ms else 0.0

cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}"
print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x")


def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark QuantFP8 group quantization implementation"
)
parser.add_argument(
"--warmup-iters", type=int, default=10, help="Number of warmup iterations"
)
parser.add_argument(
"--bench-iters", type=int, default=100, help="Number of benchmark iterations"
)
parser.add_argument(
"--shapes",
type=str,
default="32,128;64,256;16,512;128,1024;256,2048",
help="Shapes to benchmark as 'tokens,hidden;...' (default: multiple shapes)",
)
parser.add_argument(
"--group-sizes",
type=str,
default="64,128",
help="Group sizes to benchmark (comma-separated)",
)
parser.add_argument(
"--no-column-major",
action="store_true",
help="Skip column-major scale benchmarks",
)
return parser.parse_args()


def main():
if not current_platform.is_cuda():
raise RuntimeError("CUDA device is required to run this benchmark.")

args = parse_args()

shapes = []
for shape_str in args.shapes.split(";"):
tokens, hidden = map(int, shape_str.split(","))
shapes.append((tokens, hidden))

group_sizes = list(map(int, args.group_sizes.split(",")))

print("\n" + "=" * 80)
print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)")
print("=" * 80)
print(f"Device: {torch.cuda.get_device_name()}")
print(f"Warmup iterations: {args.warmup_iters}")
print(f"Benchmark iterations: {args.bench_iters}")
print("=" * 80)

print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}")
print("-" * 80)

for shape in shapes:
for gs in group_sizes:
run_benchmark(
shape,
gs,
column_major=False,
warmup_iters=args.warmup_iters,
bench_iters=args.bench_iters,
)

if not args.no_column_major:
run_benchmark(
shape,
gs,
column_major=True,
warmup_iters=args.warmup_iters,
bench_iters=args.bench_iters,
)

print("=" * 80)


if __name__ == "__main__":
main()
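
# Example invocation (a sketch; assumes a CUDA-capable GPU and uses the flags
# defined in parse_args above, with defaults applied unless overridden):
#
#   python benchmarks/kernels/benchmark_quantfp8_group.py \
#       --warmup-iters 20 --bench-iters 200 \
#       --shapes "32,128;128,1024" --group-sizes 64,128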
tests/kernels/quantization/test_fp8_quant_group.py (new file: 206 additions, 0 deletions)
@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize("batch_size", [16, 32])
@pytest.mark.parametrize("hidden_dim",
[256, 512, 513]) # Include non-divisible
@pytest.mark.parametrize("group_size", [32, 64, 128])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_basic(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None:
current_platform.seed_everything(seed)

x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

# Create QuantFP8 with group quantization
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

expected_num_groups = (hidden_dim + group_size - 1) // group_size

# Test CUDA implementation (only supports divisible dimensions)
if hidden_dim % group_size == 0:
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
assert x_quant_cuda.shape == x.shape
assert scales_cuda.shape == (batch_size, expected_num_groups)

# Test PyTorch native implementation
x_quant_native, scales_native = quant_op.forward_native(x.clone())
assert x_quant_native.shape == x.shape
assert scales_native.shape == (batch_size, expected_num_groups)

# Test column_major_scales
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (expected_num_groups, batch_size)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
current_platform.seed_everything(seed)

group_size = 64

# Test with 3D input
batch1, batch2, hidden_dim = 4, 8, 512
x_3d = torch.randn(
(batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
assert scales.shape == (batch1, batch2, hidden_dim // group_size)

# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

# Test with 4D input
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
dtype=torch.bfloat16,
device="cuda") * 8

x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
assert x_quant_4d.shape == x_4d.shape
assert scales_4d.shape == (batch1, batch2, batch3,
hidden_dim // group_size)

_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
batch3)


@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("hidden_dim", [1024])
@pytest.mark.parametrize("group_size", [128])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_cuda_native_consistency(batch_size: int,
hidden_dim: int,
group_size: int,
seed: int) -> None:
"""Compare CUDA and native implementations for consistency."""
current_platform.seed_everything(seed)

x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

# Run both implementations
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
x_quant_native, scales_native = quant_op.forward_native(x.clone())

# Check shapes match
assert x_quant_cuda.shape == x_quant_native.shape
assert scales_cuda.shape == scales_native.shape

# Scales should match
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

# Quantized values should mostly match, with rare rounding differences
# FP8 rounding at boundaries can differ between CUDA and PyTorch
diff_count = (x_quant_cuda != x_quant_native).sum().item()
diff_ratio = diff_count / x_quant_cuda.numel()
assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
current_platform.seed_everything(seed)

batch_size = 16
group_size = 64

# Test with single group (group_size >= hidden_dim)
x_small = torch.randn(
(batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
assert x_quant_small.shape == x_small.shape
assert scales_small.shape == (batch_size, 1)

# Test with zero inputs
x_zero = torch.zeros((batch_size, 256),
dtype=torch.bfloat16,
device="cuda")
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
assert x_quant_zero.shape == x_zero.shape
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

# Test very large values
x_large = torch.full((batch_size, 256),
1000.0,
dtype=torch.bfloat16,
device="cuda")
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
assert x_quant_large.shape == x_large.shape
# FP8 max is typically 448 or 224, so scales should be > 1
assert (scales_large > 1.0).all(), "Large values should have scales > 1"


@pytest.mark.parametrize(
"batch_size,hidden_dim,group_size",
[
(16, 256, 16), # Small
(64, 1024, 64), # Medium
(128, 2048, 128), # Large
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None:
current_platform.seed_everything(seed)

x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

expected_num_groups = (hidden_dim + group_size - 1) // group_size

x_quant_native, scales_native = quant_op.forward_native(x.clone())
assert x_quant_native.shape == x.shape
assert scales_native.shape == (batch_size, expected_num_groups)

if hidden_dim % group_size == 0:
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
assert x_quant_cuda.shape == x.shape
assert scales_cuda.shape == (batch_size, expected_num_groups)
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)
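

# For reference, a minimal pure-PyTorch sketch of the dynamic per-token-group
# FP8 quantization these tests assume: one scale per group of `group_size`
# elements along the last dimension, derived from the group's absolute
# maximum. This is a hypothetical helper for illustration, not the vLLM
# implementation (it relies on the `import torch` at the top of this file).
def _group_quant_fp8_reference(x: torch.Tensor,
                               group_size: int,
                               fp8_max: float = 448.0,
                               eps: float = 1e-10):
    assert x.shape[-1] % group_size == 0
    x_grouped = x.reshape(*x.shape[:-1], -1, group_size).float()
    # Per-group scale: group amax divided by the FP8 representable maximum
    # (448 for float8_e4m3fn), clamped to avoid division by zero.
    amax = x_grouped.abs().amax(dim=-1, keepdim=True)
    scales = (amax / fp8_max).clamp(min=eps)
    x_q = (x_grouped / scales).clamp(-fp8_max, fp8_max)
    x_q = x_q.reshape(x.shape).to(torch.float8_e4m3fn)
    return x_q, scales.squeeze(-1)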
vllm/model_executor/layers/fused_moe/fused_moe.py (3 additions, 1 deletion)
@@ -32,9 +32,11 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
+    _resize_cache, moe_kernel_quantize_input)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     calculate_tile_tokens_dim)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4)
 from vllm.platforms import current_platform