-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[FP8] Extend per-token-group quantization support to QuantFP8 #24342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
b0b9d48
add per-token-group quantization support to QuantFP8
tahsintunan 74bd084
Update vllm/model_executor/layers/quantization/utils/quant_utils.py
tahsintunan b50d163
Add PyTorch implementation for QuantFP8 group quantization
tahsintunan 2662be1
refactor: move FP8 quantization functions into QuantFP8
tahsintunan 4fe4578
Refactor benchmark to support all group shapes
ProExpertProg 100b11c
refactor: clean up QuantFP8 forward methods and consolidate tests
tahsintunan dd45227
refactor: test_fp8_quant_group to avoid mypy type errors
tahsintunan ff0855a
bench: add CLI args for FP8 benchmark configuration
tahsintunan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,40 +23,78 @@ | |
| @CustomOp.register("quant_fp8") | ||
| class QuantFP8(CustomOp): | ||
| """ | ||
| Quantize input tensor to per-tensor or per-token FP8. | ||
| Quantize input tensor to FP8 (per-tensor, per-token, or per-group). | ||
| This CustomOp supports both static and dynamic quantization. | ||
| """ | ||
|
|
||
| def __init__(self, | ||
| static: bool, | ||
| group_shape: GroupShape, | ||
| num_token_padding: Optional[int] = None): | ||
| num_token_padding: Optional[int] = None, | ||
| column_major_scales: bool = False): | ||
| """ | ||
|
|
||
| :param static: static or dynamic quantization | ||
| :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) | ||
| :param num_token_padding: Pad the token dimension of output to this size | ||
| :param group_shape: quantization group shape (PER_TOKEN, PER_TENSOR, | ||
| or arbitrary block size) | ||
| :param num_token_padding: Pad the token dimension of output to this | ||
| size | ||
| :param column_major_scales: For group quantization, output scales in | ||
| column major format | ||
| """ | ||
| super().__init__() | ||
| self.num_token_padding = num_token_padding | ||
| assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} | ||
| assert not static or group_shape == GroupShape.PER_TENSOR, \ | ||
| "Only per-tensor scales supported for static quantization." | ||
| self.static = static | ||
| self.group_shape = group_shape | ||
| self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN | ||
| self.num_token_padding = num_token_padding | ||
| self.column_major_scales = column_major_scales | ||
|
|
||
| self.is_group_quant = group_shape.is_per_group() | ||
| if self.is_group_quant: | ||
| assert not static, "Group quantization only supports dynamic mode" | ||
| self.group_size = group_shape.col | ||
| else: | ||
| assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we add an assert that column_major_scales is False if non group? |
||
| assert not static or group_shape == GroupShape.PER_TENSOR, \ | ||
| "Only per-tensor scales supported for static quantization." | ||
| self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN | ||
|
|
||
| def _quantize_group(self, | ||
| x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | ||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||
| per_token_group_quant_fp8) | ||
| return per_token_group_quant_fp8( | ||
tahsintunan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x, | ||
| group_size=self.group_size, | ||
| column_major_scales=self.column_major_scales, | ||
| dtype=_FP8_DTYPE) | ||
|
|
||
| def _compute_dynamic_scale( | ||
| self, x: torch.Tensor, | ||
tahsintunan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| scale_ub: Optional[torch.Tensor]) -> torch.Tensor: | ||
| if self.group_shape == GroupShape.PER_TOKEN: | ||
| x_max, _ = x.abs().max(dim=-1) | ||
| x_max = x_max.unsqueeze(-1).to(torch.float32) | ||
| if scale_ub is not None: | ||
| x_max = x_max.clamp(max=scale_ub) | ||
| else: | ||
| x_max = x.abs().max().unsqueeze(-1).to(torch.float32) | ||
|
|
||
| scale = x_max / _FP8_MAX | ||
| return scale.clamp(min=_FP8_MIN_SCALING_FACTOR) | ||
|
|
||
| def forward_cuda( | ||
| self, | ||
| x: torch.Tensor, | ||
| scale: Optional[torch.Tensor] = None, | ||
| scale_ub: Optional[torch.Tensor] = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if self.is_group_quant: | ||
| assert scale is None, "Group quantization is always dynamic" | ||
| return self._quantize_group(x) | ||
|
|
||
| assert (scale is not None) == self.static | ||
| assert scale_ub is None or (not self.static and self.group_shape | ||
| == GroupShape.PER_TOKEN | ||
| and scale_ub.numel() == 1) | ||
|
|
||
| return ops.scaled_fp8_quant( | ||
| x, | ||
| scale, | ||
|
|
@@ -70,22 +108,17 @@ def forward_native( | |
| scale: Optional[torch.Tensor] = None, | ||
| scale_ub: Optional[torch.Tensor] = None, | ||
| ): | ||
| if self.is_group_quant: | ||
| assert scale is None, "Group quantization is always dynamic" | ||
| return self._quantize_group(x) | ||
|
|
||
| assert (scale is not None) == self.static | ||
| assert scale_ub is None or (not self.static and self.group_shape | ||
| == GroupShape.PER_TOKEN | ||
| and scale_ub.numel() == 1) | ||
|
|
||
| if scale is None: | ||
| if self.group_shape == GroupShape.PER_TOKEN: | ||
| x_max, _ = x.abs().max(dim=-1) | ||
| x_max = x_max.unsqueeze(-1).to(torch.float32) | ||
| if scale_ub is not None: | ||
| x_max = x_max.clamp(max=scale_ub) | ||
| else: | ||
| x_max = x.abs().max().unsqueeze(-1).to(torch.float32) | ||
|
|
||
| scale = x_max / _FP8_MAX | ||
| scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) | ||
| scale = self._compute_dynamic_scale(x, scale_ub) | ||
|
|
||
| # Even for dynamic per-token scales, | ||
| # reciprocal performs slightly better than division | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.