[FP8] Extend per-token-group quantization support to QuantFP8 #24342
Changes from 7 commits
`benchmarks/kernels/bench_per_token_quant_fp8.py`

@@ -1,15 +1,26 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


def with_triton_mode(fn):
    """Temporarily force the Triton fallback path"""

    def wrapped(*args, **kwargs):
        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
            return fn(*args, **kwargs)

    return wrapped


# TODO(luka): use standalone_compile utility
```

@@ -21,78 +32,183 @@

```python
    return inner


torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )
def bench_compile(fn: Callable):
    # recompile for different shapes
    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
    torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
    return with_dyn_arg(fwd, 0, 0)


def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input)
torch._dynamo.config.recompile_limit = 8888


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between Triton and CUDA implementations."""
def calculate_diff(
    batch_size: int,
    hidden_size: int,
    group_shape: GroupShape,
    dtype: torch.dtype,
):
    """Calculate the difference between Inductor and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)

    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

    if torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
    out_allclose = lambda o1, o2: torch.allclose(
        o1.to(torch.float32),
        o2.to(torch.float32),
        rtol=1e-3,
        atol=1e-5,
    )
    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)

    if (
        out_allclose(cuda_out, torch_out)
        and scale_allclose(cuda_scale, torch_scale)
        and out_allclose(cuda_out, torch_eager_out)
        and scale_allclose(cuda_scale, torch_eager_scale)
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
hidden_sizes = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
batch_sizes = [1, 16, 32, 64, 128]
group_shapes = [
    GroupShape.PER_TENSOR,
    GroupShape.PER_TOKEN,
    GroupShape(1, 64),
    GroupShape(1, 128),
]
column_major_scales = [True, False]

config_gen = itertools.product(
    group_shapes,
    column_major_scales,
    batch_sizes,
    hidden_sizes,
)

configs = list(itertools.product(batch_size_range, seq_len_range))
# filter out column-major scales for non-group, reverse order
configs = list(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        line_vals=["torch", "cuda", "triton"],
        line_names=["Torch (Compiled)", "CUDA", "Triton"],
        styles=[("blue", "-"), ("green", "-"), ("black", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        plot_name="QuantFP8 performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    dtype = torch.float16
def benchmark_quantization(
    batch_size,
    hidden_size,
    provider,
    group_shape: GroupShape,
    col_major: bool,
    dtype: torch.dtype,
):
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())
        fn = lambda: quant_fp8.forward_cuda(x.clone())
    elif provider == "triton":
        if not group_shape.is_per_group():
            # Triton only supported for per-group
            return 0, 0, 0

        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
    df: pd.DataFrame,
    baseline_col: str,
    speedup_cols: list[str],
    groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Compute geometric mean speedups over a baseline column.

    Args:
        df: Input dataframe
        baseline_col: Column to use as baseline
        speedup_cols: Columns to compute speedups for
        groupby_cols: Columns to group by. If None, compute over entire df.

    Returns:
        pd.DataFrame with geometric mean speedups
    """
    from scipy.stats import gmean

    def geo_speedup(group: pd.DataFrame) -> pd.Series:
        ratios = {
            col: (group[baseline_col] / group[col]).values for col in speedup_cols
        }
        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

    if groupby_cols is None:
        result = geo_speedup(df).to_frame().T
    else:
        result = df.groupby(groupby_cols).apply(geo_speedup).reset_index()

    return result


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
    parser = FlexibleArgumentParser(
        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
    )
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )

    args = parser.parse_args()
    assert args

    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    if args.check:
        for group_shape in group_shapes:
            group_size = group_shape[1]
            print(f"{group_size=}")
            calculate_diff(
                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
            )

    df = benchmark_quantization.run(print_data=True, dtype=dtype, return_df=True)

    # Print geomean speedups
    geo_table_grouped = compute_geomean_speedups(
        df,
        baseline_col="Torch (Compiled)",
        speedup_cols=["CUDA", "Triton"],
        groupby_cols=["col_major", "group_shape"],
    )

    print("Speedup over Torch (Compiled)")
    print(geo_table_grouped.to_string(index=False))
```
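As a quick illustration of the geomean-speedup summary the benchmark now prints, the toy snippet below reproduces the same computation that `compute_geomean_speedups` performs, on a made-up two-group DataFrame. The latency values are invented for the example; a real run uses the columns produced by the `triton.testing.perf_report` sweep ("Torch (Compiled)", "CUDA", "Triton").

```python
import pandas as pd
from scipy.stats import gmean

# Invented latencies in microseconds; "Torch (Compiled)" is the baseline.
df = pd.DataFrame({
    "group_shape": ["(1, 128)", "(1, 128)", "PER_TOKEN", "PER_TOKEN"],
    "Torch (Compiled)": [10.0, 20.0, 8.0, 16.0],
    "CUDA": [4.0, 10.0, 4.0, 8.0],
})

# Geometric mean of baseline/candidate ratios per group, as in
# compute_geomean_speedups(df, "Torch (Compiled)", ["CUDA"], ["group_shape"]).
speedups = df.groupby("group_shape").apply(
    lambda g: gmean((g["Torch (Compiled)"] / g["CUDA"]).values)
)
print(speedups)  # (1, 128): ~2.24x, PER_TOKEN: 2.0x
```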
New test file (QuantFP8 group quantization tests)

@@ -0,0 +1,150 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for QuantFP8 Group Quantization implementation."""

import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.platforms import current_platform


@pytest.mark.parametrize(
    "batch_size,hidden_dim,group_size",
    [
        (16, 256, 32),  # Small
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ])
@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
                                      group_size: int, seed: int) -> None:
    """Test QuantFP8 group quantization with various configurations.

    Tests both CUDA and native implementations, column-major scales,
    and verifies consistency between implementations.
    """
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    expected_num_groups = (hidden_dim + group_size - 1) // group_size
    is_divisible = hidden_dim % group_size == 0

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    # 1. Test native implementation (always available)
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
    assert x_quant_native.shape == x.shape
    assert scales_native.shape == (batch_size, expected_num_groups)

    # 2. Test column-major scales configuration
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x.clone())
    assert scales_col.shape == (expected_num_groups, batch_size)

    # 3. Test CUDA implementation (only for divisible dimensions)
    if is_divisible:
        x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone())
        assert x_quant_cuda.shape == x.shape
        assert scales_cuda.shape == (batch_size, expected_num_groups)

        # Verify CUDA/native consistency
        assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8)

        # Quantized values should mostly match
        diff_count = (x_quant_cuda != x_quant_native).sum().item()
        diff_ratio = diff_count / x_quant_cuda.numel()
        assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}"


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int) -> None:
    current_platform.seed_everything(seed)

    group_size = 64

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 512
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True)
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)


@pytest.mark.parametrize("seed", [42])
@torch.inference_mode()
def test_quantfp8_group_edge_cases(seed: int) -> None:
    current_platform.seed_everything(seed)

    batch_size = 16
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn(
        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs
    x_zero = torch.zeros((batch_size, 256),
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256),
                         1000.0,
                         dtype=torch.bfloat16,
                         device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape
    # FP8 max is typically 448 or 224, so scales should be > 1
    assert (scales_large > 1.0).all(), "Large values should have scales > 1"
```
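For readers unfamiliar with per-token-group quantization, a minimal reference of what these tests exercise is sketched below. It is not vLLM's QuantFP8 implementation: the helper name `ref_group_quant_fp8`, the 1e-12 clamp, and the reliance on `torch.finfo(torch.float8_e4m3fn).max` (~448) are assumptions for illustration, and it only handles the case where the hidden dimension is divisible by the group size (the real op also covers non-divisible shapes and column-major scale layouts).

```python
import torch


def ref_group_quant_fp8(x: torch.Tensor, group_size: int):
    """Illustrative per-token-group FP8 quantization (assumes divisible hidden dim)."""
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max  # ~448 for e4m3fn
    orig_shape = x.shape

    # Group along the last dim: (..., hidden) -> (..., n_groups, group_size)
    x_grouped = x.reshape(*orig_shape[:-1], -1, group_size).float()
    amax = x_grouped.abs().amax(dim=-1, keepdim=True)
    # One scale per group, clamped so all-zero groups don't divide by zero
    scale = (amax / fp8_max).clamp(min=1e-12)
    x_q = (x_grouped / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q.reshape(orig_shape), scale.squeeze(-1)


# Shapes match what the tests assert: scales are (batch, hidden // group_size)
x = torch.randn(16, 256, dtype=torch.bfloat16) * 8
x_q, scales = ref_group_quant_fp8(x, group_size=64)
assert x_q.shape == x.shape and scales.shape == (16, 4)
```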
Why aren't we compiling with dynamic=True? I don't think we should be targeting shape specialization since we won't use that in practice
In practice we specialize on all shapes except the first dim (`num_tokens`). The `with_dyn_arg` helper marks that dim as dynamic to fully simulate vLLM usage 👍
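The body of `with_dyn_arg` is not shown in this diff; the sketch below is an assumed shape for such a helper, built on the real `torch._dynamo.mark_dynamic` API, to illustrate the idea of keeping only the leading dimension symbolic under `torch.compile(dynamic=False)`.

```python
import torch


def with_dyn_arg(fn, arg_index: int, dim: int):
    """Hypothetical helper: mark `dim` of positional arg `arg_index` as dynamic
    so that torch.compile(dynamic=False) still treats that dimension
    symbolically, mimicking vLLM's varying num_tokens."""

    def wrapped(*args, **kwargs):
        torch._dynamo.mark_dynamic(args[arg_index], dim)
        return fn(*args, **kwargs)

    return wrapped
```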