-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[FP8] Extend per-token-group quantization support to QuantFP8 #24342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b0b9d48
add per-token-group quantization support to QuantFP8
tahsintunan 74bd084
Update vllm/model_executor/layers/quantization/utils/quant_utils.py
tahsintunan b50d163
Add PyTorch implementation for QuantFP8 group quantization
tahsintunan 2662be1
refactor: move FP8 quantization functions into QuantFP8
tahsintunan 4fe4578
Refactor benchmark to support all group shapes
ProExpertProg 100b11c
refactor: clean up QuantFP8 forward methods and consolidate tests
tahsintunan dd45227
refactor: test_fp8_quant_group to avoid mypy type errors
tahsintunan ff0855a
bench: add CLI args for FP8 benchmark configuration
tahsintunan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| #!/usr/bin/env python | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Benchmark for QuantFP8 Group Quantization implementation.""" | ||
|
|
||
| import argparse | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 | ||
| from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| def _time_cuda( | ||
| fn, | ||
| warmup_iters: int, | ||
| bench_iters: int, | ||
| ) -> float: | ||
| # warmup | ||
| for _ in range(warmup_iters): | ||
| fn() | ||
| torch.cuda.synchronize() | ||
|
|
||
| start = torch.cuda.Event(enable_timing=True) | ||
| end = torch.cuda.Event(enable_timing=True) | ||
|
|
||
| start.record() | ||
| for _ in range(bench_iters): | ||
| fn() | ||
| end.record() | ||
| torch.cuda.synchronize() | ||
|
|
||
| return start.elapsed_time(end) / bench_iters # ms/iter | ||
|
|
||
|
|
||
| def run_benchmark( | ||
| shape: tuple[int, int], | ||
| group_size: int, | ||
| column_major: bool, | ||
| warmup_iters: int, | ||
| bench_iters: int, | ||
| ) -> None: | ||
| """Benchmark QuantFP8 with group quantization using different backends.""" | ||
| num_tokens, hidden_dim = shape | ||
|
|
||
| device = torch.device("cuda") | ||
| torch.manual_seed(42) | ||
| x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8 | ||
|
|
||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8( | ||
| static=False, group_shape=group_shape, column_major_scales=column_major | ||
| ) | ||
|
|
||
| def cuda_impl(): | ||
| return quant_op.forward_cuda(x.clone()) | ||
|
|
||
| def native_impl(): | ||
| return quant_op.forward_native(x.clone()) | ||
|
|
||
| cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters) | ||
| native_ms = _time_cuda(native_impl, warmup_iters, bench_iters) | ||
|
|
||
| speedup = cuda_ms / native_ms if native_ms else 0 | ||
|
|
||
| cfg_desc = f"shape={shape} gs={group_size:<3} col_major={column_major}" | ||
| print(f"{cfg_desc:45} | {cuda_ms:7.3f} | {native_ms:7.3f} | {speedup:6.2f}x") | ||
|
|
||
|
|
||
| def parse_args(): | ||
| parser = argparse.ArgumentParser( | ||
| description="Benchmark QuantFP8 group quantization implementation" | ||
| ) | ||
| parser.add_argument( | ||
| "--warmup-iters", type=int, default=10, help="Number of warmup iterations" | ||
| ) | ||
| parser.add_argument( | ||
| "--bench-iters", type=int, default=100, help="Number of benchmark iterations" | ||
| ) | ||
| parser.add_argument( | ||
| "--shapes", | ||
| type=str, | ||
| default="32,128;64,256;16,512;128,1024;256,2048", | ||
| help="Shapes to benchmark as 'tokens,hidden;...' (default: multiple shapes)", | ||
| ) | ||
| parser.add_argument( | ||
| "--group-sizes", | ||
| type=str, | ||
| default="64,128", | ||
| help="Group sizes to benchmark (comma-separated)", | ||
| ) | ||
| parser.add_argument( | ||
| "--no-column-major", | ||
| action="store_true", | ||
| help="Skip column-major scale benchmarks", | ||
| ) | ||
| return parser.parse_args() | ||
|
|
||
|
|
||
| def main(): | ||
| if not current_platform.is_cuda(): | ||
| raise RuntimeError("CUDA device is required to run this benchmark.") | ||
|
|
||
| args = parse_args() | ||
|
|
||
| shapes = [] | ||
| for shape_str in args.shapes.split(";"): | ||
| tokens, hidden = map(int, shape_str.split(",")) | ||
| shapes.append((tokens, hidden)) | ||
|
|
||
| group_sizes = list(map(int, args.group_sizes.split(","))) | ||
|
|
||
| print("\n" + "=" * 80) | ||
| print("QuantFP8 Group Quantization Benchmark (CUDA kernel vs PyTorch native)") | ||
| print("=" * 80) | ||
| print(f"Device: {torch.cuda.get_device_name()}") | ||
| print(f"Warmup iterations: {args.warmup_iters}") | ||
| print(f"Benchmark iterations: {args.bench_iters}") | ||
| print("=" * 80) | ||
|
|
||
| print(f"{'Configuration':45} | {'CUDA':^9} | {'Native':^9} | {'Speedup':^8}") | ||
| print("-" * 80) | ||
|
|
||
| for shape in shapes: | ||
| for gs in group_sizes: | ||
| run_benchmark( | ||
| shape, | ||
| gs, | ||
| column_major=False, | ||
| warmup_iters=args.warmup_iters, | ||
| bench_iters=args.bench_iters, | ||
| ) | ||
|
|
||
| if not args.no_column_major: | ||
| run_benchmark( | ||
| shape, | ||
| gs, | ||
| column_major=True, | ||
| warmup_iters=args.warmup_iters, | ||
| bench_iters=args.bench_iters, | ||
| ) | ||
|
|
||
| print("=" * 80) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Tests for QuantFP8 Group Quantization implementation.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 | ||
| from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||
| GroupShape) | ||
| from vllm.platforms import current_platform | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("batch_size", [16, 32]) | ||
| @pytest.mark.parametrize("hidden_dim", | ||
| [256, 512, 513]) # Include non-divisible | ||
| @pytest.mark.parametrize("group_size", [32, 64, 128]) | ||
| @pytest.mark.parametrize("seed", [42]) | ||
| @torch.inference_mode() | ||
| def test_quantfp8_group_basic(batch_size: int, hidden_dim: int, | ||
| group_size: int, seed: int) -> None: | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| x = torch.randn( | ||
| (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 | ||
|
|
||
| # Create QuantFP8 with group quantization | ||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=False) | ||
|
|
||
| expected_num_groups = (hidden_dim + group_size - 1) // group_size | ||
|
|
||
| # Test CUDA implementation (only supports divisible dimensions) | ||
| if hidden_dim % group_size == 0: | ||
| x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) | ||
| assert x_quant_cuda.shape == x.shape | ||
| assert scales_cuda.shape == (batch_size, expected_num_groups) | ||
|
|
||
| # Test PyTorch native implementation | ||
| x_quant_native, scales_native = quant_op.forward_native(x.clone()) | ||
| assert x_quant_native.shape == x.shape | ||
| assert scales_native.shape == (batch_size, expected_num_groups) | ||
tahsintunan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Test column_major_scales | ||
| quant_op_col = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=True) | ||
| _, scales_col = quant_op_col.forward_native(x.clone()) | ||
| assert scales_col.shape == (expected_num_groups, batch_size) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("seed", [42]) | ||
| @torch.inference_mode() | ||
| def test_quantfp8_group_multidimensional(seed: int) -> None: | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| group_size = 64 | ||
|
|
||
| # Test with 3D input | ||
| batch1, batch2, hidden_dim = 4, 8, 512 | ||
| x_3d = torch.randn( | ||
| (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 | ||
|
|
||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=False) | ||
|
|
||
| x_quant, scales = quant_op.forward_native(x_3d.clone()) | ||
| assert x_quant.shape == x_3d.shape | ||
| assert scales.shape == (batch1, batch2, hidden_dim // group_size) | ||
|
|
||
| # Test column_major_scales with multi-dim | ||
| quant_op_col = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=True) | ||
| _, scales_col = quant_op_col.forward_native(x_3d.clone()) | ||
| assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) | ||
|
|
||
| # Test with 4D input | ||
| batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 | ||
| x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), | ||
| dtype=torch.bfloat16, | ||
| device="cuda") * 8 | ||
|
|
||
| x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) | ||
| assert x_quant_4d.shape == x_4d.shape | ||
| assert scales_4d.shape == (batch1, batch2, batch3, | ||
| hidden_dim // group_size) | ||
|
|
||
| _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) | ||
| assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, | ||
| batch3) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("batch_size", [32]) | ||
| @pytest.mark.parametrize("hidden_dim", [1024]) | ||
| @pytest.mark.parametrize("group_size", [128]) | ||
| @pytest.mark.parametrize("seed", [42]) | ||
| @torch.inference_mode() | ||
| def test_quantfp8_group_cuda_native_consistency(batch_size: int, | ||
| hidden_dim: int, | ||
| group_size: int, | ||
| seed: int) -> None: | ||
| """Compare CUDA and native implementations for consistency.""" | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| x = torch.randn( | ||
| (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 | ||
|
|
||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=False) | ||
|
|
||
| # Run both implementations | ||
| x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) | ||
| x_quant_native, scales_native = quant_op.forward_native(x.clone()) | ||
|
|
||
| # Check shapes match | ||
| assert x_quant_cuda.shape == x_quant_native.shape | ||
| assert scales_cuda.shape == scales_native.shape | ||
|
|
||
| # Scales should match | ||
| assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) | ||
|
|
||
| # Quantized values should mostly match, with rare rounding differences | ||
| # FP8 rounding at boundaries can differ between CUDA and PyTorch | ||
| diff_count = (x_quant_cuda != x_quant_native).sum().item() | ||
| diff_ratio = diff_count / x_quant_cuda.numel() | ||
| assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("seed", [42]) | ||
| @torch.inference_mode() | ||
| def test_quantfp8_group_edge_cases(seed: int) -> None: | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| batch_size = 16 | ||
| group_size = 64 | ||
|
|
||
| # Test with single group (group_size >= hidden_dim) | ||
| x_small = torch.randn( | ||
| (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 | ||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=False) | ||
|
|
||
| x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) | ||
| assert x_quant_small.shape == x_small.shape | ||
| assert scales_small.shape == (batch_size, 1) | ||
|
|
||
| # Test with zero inputs | ||
| x_zero = torch.zeros((batch_size, 256), | ||
| dtype=torch.bfloat16, | ||
| device="cuda") | ||
| x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) | ||
| assert x_quant_zero.shape == x_zero.shape | ||
| assert (scales_zero > 0).all(), "Scales should be clamped to minimum" | ||
|
|
||
| # Test very large values | ||
| x_large = torch.full((batch_size, 256), | ||
| 1000.0, | ||
| dtype=torch.bfloat16, | ||
| device="cuda") | ||
| x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) | ||
| assert x_quant_large.shape == x_large.shape | ||
| # FP8 max is typically 448 or 224, so scales should be > 1 | ||
| assert (scales_large > 1.0).all(), "Large values should have scales > 1" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "batch_size,hidden_dim,group_size", | ||
| [ | ||
| (16, 256, 16), # Small | ||
| (64, 1024, 64), # Medium | ||
| (128, 2048, 128), # Large | ||
| (8, 513, 64), # Non-divisible (native only) | ||
| ]) | ||
| @pytest.mark.parametrize("seed", [42]) | ||
| @torch.inference_mode() | ||
| def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int, | ||
| group_size: int, seed: int) -> None: | ||
| current_platform.seed_everything(seed) | ||
|
|
||
| x = torch.randn( | ||
| (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 | ||
| group_shape = GroupShape(1, group_size) | ||
| quant_op = QuantFP8(static=False, | ||
| group_shape=group_shape, | ||
| column_major_scales=False) | ||
|
|
||
| expected_num_groups = (hidden_dim + group_size - 1) // group_size | ||
|
|
||
| x_quant_native, scales_native = quant_op.forward_native(x.clone()) | ||
| assert x_quant_native.shape == x.shape | ||
| assert scales_native.shape == (batch_size, expected_num_groups) | ||
|
|
||
| if hidden_dim % group_size == 0: | ||
| x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) | ||
| assert x_quant_cuda.shape == x.shape | ||
| assert scales_cuda.shape == (batch_size, expected_num_groups) | ||
| assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.