192 changes: 154 additions & 38 deletions benchmarks/kernels/bench_per_token_quant_fp8.py
@@ -1,15 +1,26 @@
# SPDX-License-Identifier: Apache-2.0

bc_lint warning (line 1): Function cuda_per_token_quant_fp8: function deleted
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


def with_triton_mode(fn):
"""Temporarily force the Triton fallback path"""

def wrapped(*args, **kwargs):
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
return fn(*args, **kwargs)

return wrapped


# TODO(luka): use standalone_compile utility
@@ -21,78 +32,183 @@
return inner


torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
torch_per_token_quant_fp8 = torch.compile(
QuantFP8(False, GroupShape.PER_TOKEN),
fullgraph=True,
dynamic=False, # recompile for different shapes
)
def bench_compile(fn: Callable):
# recompile for different shapes
fwd = torch.compile(fn, fullgraph=True, dynamic=False)
Comment on lines +36 to +37

Member: Why aren't we compiling with dynamic=True? I don't think we should be targeting shape specialization, since we won't use that in practice.

Collaborator: In practice we specialize on all shapes except the first dim (num_tokens). The with_dyn_arg marks that shape as dynamic to fully simulate vLLM usage 👍


# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
return with_dyn_arg(fwd, 0, 0)
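
Note: with_dyn_arg's definition is mostly collapsed out of this view (only its trailing `return inner` is visible above). A minimal sketch of what it presumably does, assuming it simply wraps torch._dynamo.mark_dynamic around the callable, matching the with_dyn_arg(fwd, 0, 0) call here:

def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int) -> Callable:
    # Mark one dim of one positional tensor argument as dynamic so that
    # torch.compile does not re-specialize on it (here: the num_tokens dim).
    def inner(*args, **kwargs):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args, **kwargs)

    return inner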


def cuda_per_token_quant_fp8(
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return ops.scaled_fp8_quant(input)
torch._dynamo.config.recompile_limit = 8888


def calculate_diff(batch_size: int, seq_len: int):
"""Calculate difference between Triton and CUDA implementations."""
bc_lint warning (line 46): Function calculate_diff: seq_len was renamed to hidden_size
def calculate_diff(
batch_size: int,
hidden_size: int,
group_shape: GroupShape,
dtype: torch.dtype,
):
"""Calculate the difference between Inductor and CUDA implementations."""
device = torch.device("cuda")
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)

quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

torch_out, torch_scale = torch_per_token_quant_fp8(x)
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

if torch.allclose(
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
out_allclose = lambda o1, o2: torch.allclose(
o1.to(torch.float32),
o2.to(torch.float32),
rtol=1e-3,
atol=1e-5,
)
scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)

if (
out_allclose(cuda_out, torch_out)
and scale_allclose(cuda_scale, torch_scale)
and out_allclose(cuda_out, torch_eager_out)
and scale_allclose(cuda_scale, torch_eager_scale)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")


batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_sizes = [1, 16, 32, 64, 128]
group_shapes = [
GroupShape.PER_TENSOR,
GroupShape.PER_TOKEN,
GroupShape(1, 64),
GroupShape(1, 128),
]
column_major_scales = [True, False]

config_gen = itertools.product(
group_shapes,
column_major_scales,
batch_sizes,
hidden_sizes,
)
Member: Could we make some of these args? It takes a really long time by default to run all of this.

Collaborator: Can that be a follow-up? It requires reworking the structure a lot (because currently this is passed to the function annotation).

Contributor Author: Addressed. Now you should be able to use something like this:

python3 benchmarks/kernels/bench_per_token_quant_fp8.py --hidden-sizes 1024 2048 4096 --batch-sizes 32 --group-sizes 128 --no-column-major
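
(The argument-parsing code itself is collapsed in this view. Purely as a hypothetical sketch of how sweep flags like those in the command above are typically wired with argparse; the real script's flags and defaults may differ:)

parser = FlexibleArgumentParser(description="Benchmark QuantFP8")
# All four flags below are assumptions for illustration only.
parser.add_argument("--hidden-sizes", type=int, nargs="+", default=hidden_sizes)
parser.add_argument("--batch-sizes", type=int, nargs="+", default=batch_sizes)
parser.add_argument("--group-sizes", type=int, nargs="+", default=None)
parser.add_argument("--no-column-major", action="store_true")
args = parser.parse_args()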


configs = list(itertools.product(batch_size_range, seq_len_range))
# filter out column-major scales for non-group shapes; reverse each tuple to match x_names order
configs = list(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
x_vals=configs,
line_arg="provider",
line_vals=["torch", "cuda"],
line_names=["Torch", "CUDA"],
styles=[("blue", "-"), ("green", "-")],
line_vals=["torch", "cuda", "triton"],
line_names=["Torch (Compiled)", "CUDA", "Triton"],
styles=[("blue", "-"), ("green", "-"), ("black", "-")],
ylabel="us",
plot_name="per-token-dynamic-quant-fp8-performance",
plot_name="QuantFP8 performance",
args={},
)
)
def benchmark_quantization(batch_size, seq_len, provider):
dtype = torch.float16
bc_lint warnings (lines 115, 119): Function benchmark_quantization: seq_len was renamed to hidden_size; group_shape was added and is now required
def benchmark_quantization(
batch_size,
hidden_size,
provider,
group_shape: GroupShape,

col_major: bool,
dtype: torch.dtype,
):
device = torch.device("cuda")

x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

quantiles = [0.5, 0.2, 0.8]
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

if provider == "torch":
fn = lambda: torch_per_token_quant_fp8(x.clone())
fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
elif provider == "cuda":
fn = lambda: cuda_per_token_quant_fp8(x.clone())
fn = lambda: quant_fp8.forward_cuda(x.clone())
elif provider == "triton":
if not group_shape.is_per_group():
# Triton only supported for per-group
return 0, 0, 0

fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
df: pd.DataFrame,
baseline_col: str,
speedup_cols: list[str],
groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
"""
Compute geometric mean speedups over a baseline column.

Args:
df: Input dataframe
baseline_col: Column to use as baseline
speedup_cols: Columns to compute speedups for
groupby_cols: Columns to group by. If None, compute over entire df.

Returns:
pd.DataFrame with geometric mean speedups
"""
from scipy.stats import gmean

def geo_speedup(group: pd.DataFrame) -> pd.Series:
ratios = {
col: (group[baseline_col] / group[col]).values for col in speedup_cols
}
return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

if groupby_cols is None:
result = geo_speedup(df).to_frame().T
else:
result = df.groupby(groupby_cols).apply(geo_speedup).reset_index()

return result
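
A toy illustration of the intended usage (the data here is made up, but the column names mirror the benchmark table consumed in __main__ below):

example = pd.DataFrame({
    "group_shape": ["PER_TOKEN", "PER_TOKEN", "(1, 128)", "(1, 128)"],
    "Torch (Compiled)": [10.0, 12.0, 20.0, 18.0],  # latencies in us
    "CUDA": [5.0, 6.0, 10.0, 9.0],
})
# Geometric-mean speedup of CUDA over the compiled Torch path, per group_shape
print(compute_geomean_speedups(example, "Torch (Compiled)", ["CUDA"], ["group_shape"]))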


if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096)
benchmark_quantization.run(print_data=True)
parser = FlexibleArgumentParser(
description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
)
parser.add_argument("-c", "--check", action="store_true")
parser.add_argument(
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
)

args = parser.parse_args()
assert args

dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

if args.check:
for group_shape in group_shapes:
group_size = group_shape[1]
print(f"{group_size=}")
calculate_diff(
batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
)

df = benchmark_quantization.run(print_data=True, dtype=dtype, return_df=True)

# Print geomean speedups
geo_table_grouped = compute_geomean_speedups(
df,
baseline_col="Torch (Compiled)",
speedup_cols=["CUDA", "Triton"],
groupby_cols=["col_major", "group_shape"],
)

print("Speedup over Torch (Compiled)")
print(geo_table_grouped.to_string(index=False))
150 changes: 150 additions & 0 deletions tests/kernels/quantization/test_fp8_quant_group.py
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
"batch_size,hidden_dim,group_size",
[
(16, 256, 32), # Small
(64, 1024, 64), # Medium
(128, 2048, 128), # Large
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int) -> None:
"""Test QuantFP8 group quantization with various configurations.

Tests both CUDA and native implementations, column-major scales,
and verifies consistency between implementations.
"""
current_platform.seed_everything(seed)

x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
expected_num_groups = (hidden_dim + group_size - 1) // group_size
is_divisible = hidden_dim % group_size == 0

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
assert x_quant_native.shape == x.shape
assert scales_native.shape == (batch_size, expected_num_groups)

# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (expected_num_groups, batch_size)

# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:
x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
assert x_quant_cuda.shape == x.shape
assert scales_cuda.shape == (batch_size, expected_num_groups)

# Verify CUDA/native consistency
assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

# Quantized values should mostly match
diff_count = (x_quant_cuda != x_quant_native).sum().item()
diff_ratio = diff_count / x_quant_cuda.numel()
assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
current_platform.seed_everything(seed)

group_size = 64

# Test with 3D input
batch1, batch2, hidden_dim = 4, 8, 512
x_3d = torch.randn(
(batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
assert scales.shape == (batch1, batch2, hidden_dim // group_size)

# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

# Test with 4D input
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
dtype=torch.bfloat16,
device="cuda") * 8

x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
assert x_quant_4d.shape == x_4d.shape
assert scales_4d.shape == (batch1, batch2, batch3,
hidden_dim // group_size)

_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
batch3)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
current_platform.seed_everything(seed)

batch_size = 16
group_size = 64

# Test with single group (group_size >= hidden_dim)
x_small = torch.randn(
(batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)

x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
assert x_quant_small.shape == x_small.shape
assert scales_small.shape == (batch_size, 1)

# Test with zero inputs
x_zero = torch.zeros((batch_size, 256),
dtype=torch.bfloat16,
device="cuda")
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
assert x_quant_zero.shape == x_zero.shape
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

# Test very large values
x_large = torch.full((batch_size, 256),
1000.0,
dtype=torch.bfloat16,
device="cuda")
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
assert x_quant_large.shape == x_large.shape
# FP8 max is typically 448 or 224, so scales should be > 1
assert (scales_large > 1.0).all(), "Large values should have scales > 1"