-
Notifications
You must be signed in to change notification settings - Fork 6.1k
w4a16 nvfp14 quant support #25535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
w4a16 nvfp14 quant support #25535
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,157 @@ | ||||||
| # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
| import logging | ||||||
| from collections.abc import Callable | ||||||
| from typing import Optional | ||||||
|
|
||||||
| import torch | ||||||
| from torch.nn.parameter import Parameter | ||||||
|
|
||||||
| from sglang.srt.layers.parameter import ( | ||||||
| GroupQuantScaleParameter, | ||||||
| ModelWeightParameter, | ||||||
| PerTensorScaleParameter, | ||||||
| ) | ||||||
| from sglang.srt.layers.quantization.compressed_tensors.schemes import ( | ||||||
| CompressedTensorsLinearScheme, | ||||||
| ) | ||||||
| from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend | ||||||
| from sglang.srt.layers.quantization.modelopt_quant import ( | ||||||
| enable_flashinfer_fp4_gemm, | ||||||
| fp4_gemm, | ||||||
| fp4_quantize, | ||||||
| ) | ||||||
| from sglang.srt.layers.quantization.utils import swizzle_blockscale | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
| __all__ = ["CompressedTensorsW4A16Fp4"] | ||||||
|
|
||||||
|
|
||||||
| class CompressedTensorsW4A16Fp4(CompressedTensorsLinearScheme): | ||||||
| """weight-only NVFP4 quantization (w4a16).""" | ||||||
|
|
||||||
| def __init__(self): | ||||||
| self.group_size = 16 | ||||||
|
|
||||||
| @classmethod | ||||||
| def get_min_capability(cls) -> int: | ||||||
| return 100 | ||||||
|
|
||||||
| def create_weights( | ||||||
| self, | ||||||
| layer: torch.nn.Module, | ||||||
| output_partition_sizes: list[int], | ||||||
| input_size_per_partition: int, | ||||||
| params_dtype: torch.dtype, | ||||||
| weight_loader: Callable, | ||||||
| **kwargs, | ||||||
| ): | ||||||
| output_size_per_partition = sum(output_partition_sizes) | ||||||
| layer.logical_widths = output_partition_sizes | ||||||
| layer.input_size_per_partition = input_size_per_partition | ||||||
| layer.output_size_per_partition = output_size_per_partition | ||||||
|
|
||||||
| weight = ModelWeightParameter( | ||||||
| data=torch.empty( | ||||||
| sum(output_partition_sizes), | ||||||
| input_size_per_partition // 2, | ||||||
| dtype=torch.uint8, | ||||||
| ), | ||||||
| input_dim=1, | ||||||
| output_dim=0, | ||||||
| weight_loader=weight_loader, | ||||||
| ) | ||||||
| layer.register_parameter("weight_packed", weight) | ||||||
|
|
||||||
| weight_global_scale = PerTensorScaleParameter( | ||||||
| data=torch.empty(len(output_partition_sizes), dtype=torch.float32), | ||||||
| weight_loader=weight_loader, | ||||||
| ) | ||||||
| layer.register_parameter("weight_global_scale", weight_global_scale) | ||||||
|
|
||||||
| weight_scale = GroupQuantScaleParameter( | ||||||
| data=torch.empty( | ||||||
| sum(output_partition_sizes), | ||||||
| input_size_per_partition // self.group_size, | ||||||
| dtype=torch.float8_e4m3fn, | ||||||
| ), | ||||||
| input_dim=1, | ||||||
| output_dim=0, | ||||||
| weight_loader=weight_loader, | ||||||
| ) | ||||||
| layer.register_parameter("weight_scale", weight_scale) | ||||||
|
|
||||||
| def process_weights_after_loading(self, layer) -> None: | ||||||
| weight_gs = layer.weight_global_scale.max().to(torch.float32) | ||||||
| input_gs = (1.0 / weight_gs).to(torch.float32) | ||||||
| layer.input_global_scale = Parameter(input_gs, requires_grad=False) | ||||||
| layer.weight_global_scale = Parameter(weight_gs, requires_grad=False) | ||||||
|
|
||||||
| if get_fp4_gemm_runner_backend().is_flashinfer_trtllm(): | ||||||
| from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a | ||||||
|
|
||||||
| weight = layer.weight_packed.data | ||||||
| weight_scale = layer.weight_scale.data | ||||||
|
|
||||||
| epilogue_tile_m = 128 | ||||||
| weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) | ||||||
| weight_scale = ( | ||||||
| shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) | ||||||
| .reshape(weight_scale.shape) | ||||||
| .view(torch.float8_e4m3fn) | ||||||
| ) | ||||||
|
|
||||||
| layer.weight_scale = Parameter(weight_scale, requires_grad=False) | ||||||
| layer.weight_packed = Parameter(weight, requires_grad=False) | ||||||
|
Comment on lines
+92
to
+107
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This section is missing the necessary padding logic for the |
||||||
| else: | ||||||
| swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) | ||||||
| layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) | ||||||
| layer.weight_packed = Parameter( | ||||||
| layer.weight_packed.data, requires_grad=False | ||||||
| ) | ||||||
|
|
||||||
| layer.alpha = Parameter( | ||||||
| 1.0 / (input_gs * weight_gs), | ||||||
| requires_grad=False, | ||||||
| ) | ||||||
|
|
||||||
| def apply_weights( | ||||||
| self, | ||||||
| layer: torch.nn.Module, | ||||||
| x: torch.Tensor, | ||||||
| bias: Optional[torch.Tensor] = None, | ||||||
| ) -> torch.Tensor: | ||||||
| output_dtype = x.dtype | ||||||
| w_n, _ = layer.weight_packed.shape | ||||||
| output_shape = [x.shape[0], w_n] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
|
|
||||||
| x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale) | ||||||
|
|
||||||
| assert x_fp4.dtype == torch.uint8 | ||||||
| assert layer.weight_packed.dtype == torch.uint8 | ||||||
| assert layer.weight_scale.dtype == torch.float8_e4m3fn | ||||||
| assert layer.alpha.dtype == torch.float32 | ||||||
|
|
||||||
| w = layer.weight_packed | ||||||
| w_blockscale = layer.weight_scale | ||||||
| if ( | ||||||
| enable_flashinfer_fp4_gemm | ||||||
| and not get_fp4_gemm_runner_backend().is_cutlass() | ||||||
| ): | ||||||
| w = layer.weight_packed.T | ||||||
| w_blockscale = layer.weight_scale.T | ||||||
|
|
||||||
| out = fp4_gemm( | ||||||
| x_fp4, | ||||||
| w, | ||||||
| x_blockscale, | ||||||
| w_blockscale, | ||||||
| layer.alpha, | ||||||
| output_dtype, | ||||||
| w_n, | ||||||
| ) | ||||||
|
Comment on lines
+146
to
+154
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||||||
| if bias is not None: | ||||||
| out = out + bias | ||||||
| return out.view(*output_shape) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Synthesizing a static activation scale (
input_gs) from the weight scale (weight_gs) is mathematically incorrect forw4a16(weight-only) quantization. Since the checkpoint does not provide activation scales, the model should ideally use dynamic quantization for activations (calculating the scale from the inputxat runtime) to maintain accuracy. Using a fixed scale derived from weights will likely result in poor model performance.