[not ready for review] extend fp8 online quant with blockwise scaling #32485
base: main
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from enum import Enum
from typing import TYPE_CHECKING, Any, Optional

import torch

@@ -103,6 +104,13 @@
logger = init_logger(__name__)


class OnlineQuantScalingType(Enum):
    # TODO(before land): align on naming and add descriptive comments
    # to each enum value
    TENSORWISE = "tensorwise"
    BLOCKWISE = "blockwise"


class Fp8Config(QuantizationConfig):
    """Config class for FP8."""

@@ -140,6 +148,11 @@
        )
        self.weight_block_size = weight_block_size

        # TODO(before land): hook this up to user UI, for now hardcode it here
        self.online_quant_scaling_type = OnlineQuantScalingType.BLOCKWISE
        # self.online_quant_scaling_type = OnlineQuantScalingType.TENSORWISE
        self.online_block_size = [128, 128]  # [block_n, block_k]

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp8"

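A reading aid, not part of the diff: the two enum values differ only in how many scales accompany a weight. The arithmetic below is illustrative, with made-up dimensions N and K.

import math

# Illustrative only: scale counts for a hypothetical [N, K] weight with the
# hardcoded online_block_size of [128, 128] above.
N, K = 4096, 11008
block_n, block_k = 128, 128

tensorwise_scales = 1  # one scale for the whole tensor
blockwise_scales = math.ceil(N / block_n) * math.ceil(K / block_k)
print(blockwise_scales)  # 32 * 86 = 2752 per-block scales

Blockwise scaling tracks local dynamic range at the cost of extra scale storage and block-aware matmul kernels.
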
@@ -328,20 +341,33 @@
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.act_q_static = self.quant_config.activation_scheme == "static"

        if self.weight_block_size:
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        elif (
            self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
            self.act_q_group_shape = GroupShape(
                1, self.quant_config.online_block_size[0]
            )
        else:
            # Use per-token quantization for better perf if dynamic and cutlass
            if not self.act_q_static and cutlass_fp8_supported():
                self.act_q_group_shape = GroupShape.PER_TOKEN
            else:
                self.act_q_group_shape = GroupShape.PER_TENSOR

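Another illustrative aside, not part of the diff: assuming GroupShape(rows, cols) describes the per-group extent over the (token, hidden) dimensions, as the PER_TOKEN and PER_TENSOR constants above suggest, GroupShape(1, 128) quantizes activations in 1x128 groups. The sizes below are hypothetical.

import math

# With act_q_group_shape = GroupShape(1, 128), each token row is split into
# 1 x 128 groups, one activation scale per group.
num_tokens, hidden_size = 16, 4096
groups_per_token = math.ceil(hidden_size / 128)   # 32
total_act_scales = num_tokens * groups_per_token  # 512
print(groups_per_token, total_act_scales)
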
        if self.block_quant:
        if (
            self.block_quant
            or self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
Comment on lines +361 to +365
Contributor
This condition (`self.block_quant or self.quant_config.online_quant_scaling_type is OnlineQuantScalingType.BLOCKWISE`) is repeated in several places; consider extracting it into a property:

    @property
    def _is_blockwise_quant(self):
        return (self.block_quant or
                self.quant_config.online_quant_scaling_type is
                OnlineQuantScalingType.BLOCKWISE)

You could then use `self._is_blockwise_quant` at each call site.

            block_size = self.weight_block_size or self.quant_config.online_block_size
            assert block_size is not None
            assert not self.act_q_static
            assert self.weight_block_size is not None
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                weight_group_shape=GroupShape(*block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,

@@ -487,8 +513,38 @@
        # If checkpoint not serialized fp8, quantize the weights.
        else:
            if not self.quant_config.is_checkpoint_fp8_serialized:
                qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
                weight = qweight.t()
                # Online quantization
                if (
                    self.quant_config.online_quant_scaling_type
                    is OnlineQuantScalingType.BLOCKWISE
                ):
                    # blockwise
                    from vllm.utils.deep_gemm import per_block_cast_to_fp8
                    block_size = self.quant_config.online_block_size
                    # layer.weight is [N, K] where N=output_size, K=input_size
                    qweight, weight_scale_inv = per_block_cast_to_fp8(
                        layer.weight, block_size=block_size
                    )
                    # qweight: [N, K] in FP8
                    # weight_scale_inv: [N/block_n, K/block_k] - inverse scales
                    # Note: block ops expect [N, K] format (no transpose)
                    replace_parameter(layer, "weight", qweight.data)
                    replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
                    layer.weight_block_size = block_size
                    size_k_first = False
                else:
                    # tensorwise
                    assert (
                        self.quant_config.online_quant_scaling_type
                        is OnlineQuantScalingType.TENSORWISE
                    )
                    qweight, weight_scale = ops.scaled_fp8_quant(
                        layer.weight, scale=None
                    )
                    weight = qweight.t()
                    replace_parameter(layer, "weight", weight.data)
                    replace_parameter(layer, "weight_scale", weight_scale.data)

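For readers unfamiliar with per-block casting, the sketch below shows roughly what a helper like per_block_cast_to_fp8 is expected to do: pad the [N, K] weight to block multiples, compute one scale per 128x128 block, and return the FP8 weight together with the inverse scales. It is a standalone illustration, not the actual implementation in vllm.utils.deep_gemm, whose padding and return shapes may differ.

import torch

def per_block_cast_to_fp8_sketch(w: torch.Tensor, block_size=(128, 128)):
    # Illustrative per-block FP8 quantization of a 2-D [N, K] weight.
    n, k = w.shape
    bn, bk = block_size
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # Pad N and K up to block multiples so every block is full-sized.
    n_pad = (n + bn - 1) // bn * bn
    k_pad = (k + bk - 1) // bk * bk
    padded = torch.zeros(n_pad, k_pad, dtype=w.dtype, device=w.device)
    padded[:n, :k] = w

    # View as [N/bn, bn, K/bk, bk] blocks and take the per-block absolute max.
    blocks = padded.view(n_pad // bn, bn, k_pad // bk, bk)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-4)

    # Scale each block into the FP8 range and keep the inverse scale.
    scaled = (blocks * (fp8_max / amax)).to(torch.float8_e4m3fn)
    qweight = scaled.view(n_pad, k_pad)[:n, :k]
    scale_inv = (amax / fp8_max).squeeze(1).squeeze(-1)  # [N/bn, K/bk]
    return qweight, scale_inv

The returned scale_inv has shape [ceil(N/block_n), ceil(K/block_k)], matching the weight_scale_inv comment in the diff above.
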
        # If checkpoint is fp8 per-tensor, handle that there are N scales for N
        # shards in a fused module

@@ -512,9 +568,9 @@
                input_scale = input_scale.max()
            weight = weight.t()

            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)
            # Update layer with new values.
            replace_parameter(layer, "weight", weight.data)
            replace_parameter(layer, "weight_scale", weight_scale.data)

            if input_scale is not None:
                replace_parameter(layer, "input_scale", input_scale)

@@ -529,7 +585,11 @@
            del layer.input_scale
            return

        if self.block_quant:
        if (
            self.block_quant
            or self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
            maybe_post_process_fp8_weight_block(layer)

    def apply(

@@ -541,8 +601,11 @@
        # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
        # we will use BF16 dequant when DeepGEMM is not supported.
        if vllm_is_batch_invariant():
            if self.block_quant:
                assert self.weight_block_size is not None
            if (
                self.block_quant
                or self.quant_config.online_quant_scaling_type
                is OnlineQuantScalingType.BLOCKWISE
            ):

Blockwise condition breaks serialized per-tensor FP8 checkpoints (High Severity)
The conditions checking `self.quant_config.online_quant_scaling_type is OnlineQuantScalingType.BLOCKWISE` also match checkpoints that are already FP8-serialized with per-tensor scales, because the scaling type is currently hardcoded to BLOCKWISE in `Fp8Config`; such checkpoints would be routed down the block-quant path even though they carry no block scales.
Additional Locations (1)
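If that is the problem, one possible shape of a guard is sketched below. The helper name and parameters are hypothetical, not part of the diff or of vLLM.

def use_block_path(block_quant: bool,
                   is_checkpoint_fp8_serialized: bool,
                   online_blockwise: bool) -> bool:
    # Hypothetical predicate: take the block path for true block-quant
    # checkpoints, or for online quantization of unquantized checkpoints only.
    return block_quant or (online_blockwise and not is_checkpoint_fp8_serialized)
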
                return self.w8a8_block_fp8_linear.apply(
                    input=x,
                    weight=layer.weight,

@@ -591,9 +654,11 @@
                bias=bias,
            )

        if self.block_quant:
            assert self.weight_block_size is not None

        if (

Marlin path accesses non-existent weight scale attribute (High Severity)
The Marlin code path in the linear method still references the per-tensor `weight_scale` attribute, which the online blockwise path does not create (it registers `weight_scale_inv` instead).
Additional Locations (1)

            self.block_quant
            or self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,

@@ -1089,7 +1154,21 @@
        super().__init__(quant_config, layer)
        assert not quant_config.is_checkpoint_fp8_serialized
        assert quant_config.activation_scheme == "dynamic"
        assert quant_config.weight_block_size is None

Contributor
(Truncated review comment on one of the assertions above, with a suggested change.)

        # Override parent class attributes for online blockwise quantization
        if (
            self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
            self.weight_block_size = self.quant_config.online_block_size
            self.block_quant = True
            self.weight_scale_name = "weight_scale_inv"
            # Re-select backend with correct block_quant flag
            self.fp8_backend = select_fp8_moe_backend(
                block_quant=self.block_quant,

Check failure on line 1168 in vllm/model_executor/layers/quantization/fp8.py

                tp_size=layer.moe_parallel_config.tp_size,
                with_lora_support=self.moe.is_lora_enabled,
            )

Parent validation rejects online blockwise before child overrides (Medium Severity)
The parent class `__init__` (called via `super().__init__` above) runs its validation and backend selection before this subclass overrides `self.block_quant`, `self.weight_block_size`, and `self.weight_scale_name` for online blockwise quantization.
Additional Locations (1)

    def create_weights(
        self,

@@ -1168,16 +1247,43 @@
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        if (
            self.quant_config.online_quant_scaling_type
            is OnlineQuantScalingType.BLOCKWISE
        ):
            # For blockwise, scales are per block (typically 128x128)
            block_size = self.quant_config.online_block_size
            block_n, block_k = block_size[0], block_size[1]
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        else:
            # For tensorwise, scales are per expert
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

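To make the scale shapes above concrete, here is a small worked example with made-up sizes (not taken from any specific model or from the diff):

# Illustrative scale shapes for the blockwise MoE branch above,
# with hypothetical sizes.
num_experts = 60
intermediate_size_per_partition = 1408
hidden_size = 2048
block_n, block_k = 128, 128

w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),  # 2 * 11 = 22
    (hidden_size + block_k - 1) // block_k,                            # 16
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,                            # 16
    (intermediate_size_per_partition + block_k - 1) // block_k,        # 11
)
print(w13_scale_shape)  # (60, 22, 16)
print(w2_scale_shape)   # (60, 16, 11)
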
@@ -1192,16 +1298,39 @@
            fp8_dtype = current_platform.fp8_dtype()
            w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
            w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
            w13_scale = layer.w13_weight_scale
            w2_scale = layer.w2_weight_scale

            for expert in range(layer.local_num_experts):
                w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                    layer.w13_weight[expert, :, :]
                )
                w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                    layer.w2_weight[expert, :, :]
                )
            if (
                self.quant_config.online_quant_scaling_type
                is OnlineQuantScalingType.BLOCKWISE
            ):
                # Blockwise quantization
                from vllm.utils.deep_gemm import per_block_cast_to_fp8
                block_size = self.quant_config.online_block_size
                w13_scale = layer.w13_weight_scale_inv
                w2_scale = layer.w2_weight_scale_inv

                for expert in range(layer.local_num_experts):
                    w13[expert, :, :], w13_scale[expert, :, :] = per_block_cast_to_fp8(
                        layer.w13_weight[expert, :, :], block_size=block_size
                    )
                    w2[expert, :, :], w2_scale[expert, :, :] = per_block_cast_to_fp8(
                        layer.w2_weight[expert, :, :], block_size=block_size
                    )

                layer.weight_block_size = block_size
            else:
                # Tensorwise quantization
                w13_scale = layer.w13_weight_scale
                w2_scale = layer.w2_weight_scale

                for expert in range(layer.local_num_experts):
                    w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
                        layer.w13_weight[expert, :, :]
                    )
                    w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
                        layer.w2_weight[expert, :, :]
                    )

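A hypothetical sanity check, not in the diff, for the assumption this loop relies on: each expert's per-block scale produced by the cast must match the slice preallocated in create_weights.

import torch

def check_expert_scale_shape(expert_weight: torch.Tensor,
                             scale_slice: torch.Tensor,
                             block_size=(128, 128)) -> None:
    # Hypothetical helper: verify a preallocated per-expert scale slice matches
    # the [ceil(N/block_n), ceil(K/block_k)] shape a per-block cast would produce.
    n, k = expert_weight.shape
    block_n, block_k = block_size
    expected = ((n + block_n - 1) // block_n, (k + block_k - 1) // block_k)
    assert tuple(scale_slice.shape) == expected, (
        f"scale slice {tuple(scale_slice.shape)} != expected {expected}"
    )
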
            # Shuffle weights to runtime format and setup kernel.
            self._setup_kernel(

Test validation disabled with debug model change (Medium Severity)
The test model was changed from facebook/opt-125m to Qwen/Qwen1.5-MoE-A2.7B, but the check_model validation function (which references opt-125m-specific layer paths like model.model.decoder.layers[0].fc1) was commented out instead of updated. The test now only runs inference without validating that quantization was applied correctly. Additionally, test parameterization was reduced, decreasing coverage.
Additional Locations (2)
tests/quantization/test_fp8.py#L194-L198
tests/quantization/test_fp8.py#L128-L133
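
For illustration only, a rough idea of what a restored check could look like. The attribute names weight and weight_scale_inv follow this diff; the function itself is hypothetical and is not the test's actual check_model.

import torch

def check_blockwise_fp8_applied(model: torch.nn.Module) -> None:
    # Hypothetical check: at least one module ended up with an FP8 weight, and
    # any accompanying weight_scale_inv is 2-D, as blockwise scales would be.
    found = False
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor) and weight.dtype == torch.float8_e4m3fn:
            found = True
            scale_inv = getattr(module, "weight_scale_inv", None)
            if isinstance(scale_inv, torch.Tensor):
                assert scale_inv.dim() == 2, name
    assert found, "no FP8-quantized weights found"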