[4/N] Quantization Refactor: AWQ schemes and Kernel call and weight init split #21126
Merged: sglang-npu-bot merged 38 commits into sgl-project:main from Alisehen:hyc/awq-scheme-refactor on Apr 30, 2026.
Commits (38, all authored by Alisehen):
- 6c29a1d Refactor AWQ into schemes and backend kernels
- 0618e35 bug fix
- 90b4283 reduce extra method
- cdde14e Merge branch 'sgl-project:main' into hyc/awq-scheme-refactor
- cf8bfeb Merge remote-tracking branch 'origin/main' into hyc/awq-scheme-refactor
- 571b834 bug fix
- 9d89014 Merge branch 'main' into hyc/awq-scheme-refactor
- 0abe523 Merge branch 'main' of https://github.com/sgl-project/sglang into hyc…
- 2fd568d follow the reviews
- 7fd5b63 Merge branch 'main' into hyc/awq-scheme-refactor
- ffa6ed5 fix ci
- c6bde30 Merge branch 'hyc/awq-scheme-refactor' of https://github.com/Alisehen…
- 1cfb42f Merge branch 'main' into hyc/awq-scheme-refactor
- 4ee2dda fix
- ab79808 fix
- f37729c fix
- 0ed8dfa Merge branch 'main' into hyc/awq-scheme-refactor
- 336a3cf Merge branch 'main' of https://github.com/sgl-project/sglang into hyc…
- b1bf125 Fix AWQ scheme import ordering
- 27b39bc Refactor AWQ CPU path into schemes
- c98b0f6 Merge branch 'main' into hyc/awq-scheme-refactor
- 37b08bb Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy…
- ff7c24f Format AWQ marlin fallback
- 9751bab Merge branch 'main' into hyc/awq-scheme-refactor
- 178c91a Address AWQ scheme review feedback
- b42cde5 Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy…
- 4350cb6 Fix AWQ dequant test import
- e04e8f6 Merge branch 'main' into hyc/awq-scheme-refactor
- 825b704 bug fix
- aee44f8 Merge branch 'main' into hyc/awq-scheme-refactor
- c81189d Merge branch 'main' into hyc/awq-scheme-refactor
- 0fdcebb Merge branch 'main' into hyc/awq-scheme-refactor
- e93958a Fix AWQ scheme setup and MoE native swiglu import
- c6255cb Merge branch 'main' into hyc/awq-scheme-refactor
- e5088b5 Merge branch 'main' into hyc/awq-scheme-refactor
- 513bbea Merge branch 'main' of https://github.com/sgl-project/sglang into hyc…
- 1cc3126 Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy…
- 1b0145c Use upstream GPT-OSS SwiGLU helper
New file: python/sglang/srt/hardware_backend/gpu/quantization/awq_kernels.py (255 additions, 0 deletions)
```python
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.moe import MoeRunner
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
from sglang.srt.layers.quantization.marlin_utils import (
    apply_awq_marlin_linear,
    awq_to_marlin_zero_points,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    moe_awq_to_marlin_zero_points,
)
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
from sglang.srt.utils import is_hip, is_xpu

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        StandardDispatchOutput,
    )
    from sglang.srt.layers.quantization.base_config import QuantizationConfig

# The Marlin repack ops stay None unless the JIT-kernel imports below succeed.
awq_marlin_moe_repack = None
awq_marlin_repack = None


def _unsupported_awq_dequantize(*args, **kwargs):
    raise RuntimeError("AWQ GPU kernels are unavailable on the current platform.")


# Bind the best available dequantize kernel for the current platform:
# sgl_kernel on XPU, Triton on HIP, and the JIT kernels elsewhere, falling
# back to Triton and then sgl_kernel if the JIT import fails.
awq_dequantize = _unsupported_awq_dequantize

if is_xpu():
    try:
        from sgl_kernel import awq_dequantize
    except ImportError:
        pass
elif is_hip():
    try:
        from sglang.srt.layers.quantization.awq.awq_triton import (
            awq_dequantize_triton as awq_dequantize,
        )
    except ImportError:
        pass
else:
    try:
        from sglang.jit_kernel.awq_dequantize import awq_dequantize
        from sglang.jit_kernel.awq_marlin_repack import (
            awq_marlin_moe_repack,
            awq_marlin_repack,
        )
        from sglang.srt.utils.custom_op import register_custom_op_from_extern

        # The fake_impl only reports the output shape: dequantization
        # expands the packed last dim by a factor of 8.
        awq_dequantize = register_custom_op_from_extern(
            awq_dequantize,
            fake_impl=lambda qweight, scales, qzeros: qweight.new_empty(
                qweight.shape[:-1] + (qweight.shape[-1] * 8,), dtype=scales.dtype
            ),
        )
    except ImportError:
        try:
            from sglang.srt.layers.quantization.awq.awq_triton import (
                awq_dequantize_triton as awq_dequantize,
            )
        except ImportError:
            try:
                from sgl_kernel import awq_dequantize
            except ImportError:
                pass

_, scalar_types = get_scalar_types()
```
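The `fake_impl` above encodes AWQ's int4 packing: eight 4-bit weights live in each int32 word, so dequantization multiplies the packed last dimension by a pack factor of 8. A minimal sketch of that shape contract, with hypothetical sizes (the tensor names and dimensions are illustrative, not taken from this diff):

```python
import torch

PACK_FACTOR = 32 // 4  # eight 4-bit values per int32 word

qweight = torch.zeros(128, 16, dtype=torch.int32)  # packed (K, N / 8)
scales = torch.ones(1, 16 * PACK_FACTOR, dtype=torch.float16)

# Same expression as the fake_impl: only the last dim grows.
out_shape = qweight.shape[:-1] + (qweight.shape[-1] * PACK_FACTOR,)
print(out_shape)  # torch.Size([128, 128])
```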
```python
class AWQLinearKernel:
    """Fallback linear kernel: dequantize the AWQ weight, then run a dense GEMM."""

    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # No repacking needed; just freeze the loaded tensors as Parameters.
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        # qweight is (K, N / pack_factor); the dequantized weight is (K, N).
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = awq_dequantize(qweight, scales, qzeros)
        out = torch.matmul(reshaped_x, out)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
```
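Since `apply()` dequantizes on every forward pass, the only per-layer state is the packed tensors themselves. A minimal sketch of the same shape bookkeeping with a stand-in dequantizer (all names and sizes here are hypothetical, not the library's kernel):

```python
import torch

PACK_FACTOR = 8  # AWQ int4: eight 4-bit values per int32

def fake_dequantize(qweight, scales, qzeros):
    # Stand-in for the real kernel: return a dense fp16 weight of shape (K, N).
    k, n_packed = qweight.shape
    return torch.randn(k, n_packed * PACK_FACTOR, dtype=torch.float16)

x = torch.randn(2, 4, 128, dtype=torch.float16)    # (batch, seq, K)
qweight = torch.zeros(128, 32, dtype=torch.int32)  # (K, N / pack_factor)

out_shape = x.shape[:-1] + (qweight.shape[-1] * PACK_FACTOR,)
w = fake_dequantize(qweight, None, None)           # (K, N)
out = torch.matmul(x.reshape(-1, x.shape[-1]), w)  # (batch * seq, N)
print(out.reshape(out_shape).shape)                # torch.Size([2, 4, 256])
```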
```python
class AWQMarlinLinearKernel:
    """Marlin linear kernel: repack weights into Marlin layout once at load time."""

    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        layer.workspace = marlin_make_workspace(device)

        # Repack the AWQ weight, scales, and zero points into Marlin's layout.
        marlin_qweight = awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "scales", marlin_scales)

        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # AWQ has no activation reordering, so g_idx stays empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )
```
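This split between weight init (`process_weights_after_loading`) and the kernel call (`apply`) is what lets a scheme swap kernels without touching layer code. As a hedged sketch (`select_awq_linear_kernel` is hypothetical and not part of this diff), a scheme could choose the Marlin path only when its repack op was successfully imported:

```python
def select_awq_linear_kernel(quant_config):
    # awq_marlin_repack is None when the JIT kernels failed to import above.
    if awq_marlin_repack is not None:
        return AWQMarlinLinearKernel(quant_config)
    return AWQLinearKernel(quant_config)
```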
```python
class AWQMoEKernel:
    """MoE kernel: repack both expert weight groups (w13 and w2) for Marlin."""

    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config
        self.runner: Optional[MoeRunner] = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        # AWQ has no activation reordering, so the sort indices are empty.
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: "StandardDispatchOutput",
    ) -> "CombineInput":
        # The owning scheme must attach a MoeRunner before the first forward.
        if self.runner is None:
            raise RuntimeError("moe runner is not initialized")

        quant_info = MarlinMoeQuantInfo(
            w13_qweight=layer.w13_qweight,
            w2_qweight=layer.w2_qweight,
            w13_scales=layer.w13_scales,
            w2_scales=layer.w2_scales,
            w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
            w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
            w13_qzeros=layer.w13_qzeros,
            w2_qzeros=layer.w2_qzeros,
            weight_bits=self.quant_config.weight_bits,
        )
        return self.runner.run(dispatch_output, quant_info)
```
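Because the runner is injected after construction, calling `apply()` on a freshly built kernel fails fast rather than silently misbehaving. A small sketch of that guard (the construction arguments are illustrative):

```python
kernel = AWQMoEKernel(quant_config=None)
try:
    kernel.apply(layer=None, dispatch_output=None)
except RuntimeError as err:
    print(err)  # moe runner is not initialized
```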
Review comment: The `size_k` parameter for `marlin_moe_permute_scales` appears to be incorrect for `w13_scales`. The k dimension of the w13 weight matrix is `hidden_size`, and the scales are grouped along this dimension. Therefore, `size_k` should be `layer.hidden_size` instead of `layer.intermediate_size_per_partition`.
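For concreteness, the reviewer's suggestion would amount to the following change in `process_weights_after_loading` (hedged: whether `hidden_size` is the correct k extent depends on the w13 layout in this refactor, and `layer.hidden_size` is assumed to exist):

```python
marlin_w13_scales = marlin_moe_permute_scales(
    s=layer.w13_scales,
    size_k=layer.hidden_size,  # k dim of w13 is hidden_size, per the review
    size_n=layer.w13_scales.shape[2],
    group_size=self.quant_config.group_size,
)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
```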