[4/N] Quantization Refactor: AWQ schemes and kernel call and weight init split #21126
Conversation
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the AWQ quantization framework by introducing a more organized and extensible architecture. The changes aim to improve code clarity and facilitate the integration of diverse hardware backends and quantization methods, moving towards a more unified and maintainable quantization system.
We kept both AWQConfig and AWQMarlinConfig for now because awq and awq_marlin are still exposed as separate quantization entry points with distinct compatibility and fallback behavior.
Code Review
This pull request is a significant refactoring of the AWQ quantization logic, moving to a more modular scheme-based architecture. This is a great improvement for maintainability and extensibility. The changes also include splitting backend-specific kernels for GPU and NPU, which is a clean separation of concerns. I've found a critical bug in the GPU kernel logic and a few areas for improvement in the new NPU kernels and scheme definitions.
marlin_w13_scales = marlin_moe_permute_scales(
    s=layer.w13_scales,
    size_k=layer.intermediate_size_per_partition,
The size_k parameter for marlin_moe_permute_scales appears to be incorrect for w13_scales. The k dimension of the w13 weight matrix is hidden_size, and the scales are grouped along this dimension. Therefore, size_k should be layer.hidden_size instead of layer.intermediate_size_per_partition.
Suggested change:
-    size_k=layer.intermediate_size_per_partition,
+    size_k=layer.hidden_size,
qweight_tmp.bitwise_or_(
    ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
)
The bitwise operation can be simplified for clarity and potentially better performance. The multiplication by a power of two can be replaced with a left bit shift. Also, the final bitwise AND is redundant since the shifted value is already a nibble.
Suggested change:
-    qweight_tmp.bitwise_or_(
-        ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
-    )
+    qweight_tmp.bitwise_or_(
+        ((layer.qweight.data >> shift_num) & 0xF) << (4 * i)
+    )
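For intuition, here is a quick standalone check (plain Python ints, illustrative values only, and a simplified nibble order; the real kernel derives shift_num from AWQ's interleaved packing) that the suggested form extracts the same nibble as the original expression:

# Sanity check: multiplying by 2**(4*i) and masking afterwards is equivalent
# to masking the nibble first and shifting it into place.
for packed in (0x00000000, 0x00000005, 0xABCDEF12, 0xFFFFFFFF):
    for i in range(8):  # 8 nibbles in a 32-bit packed word
        shift_num = 4 * i  # simplified; not AWQ's actual nibble order
        original = ((packed >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
        simplified = ((packed >> shift_num) & 0xF) << (4 * i)
        assert original == simplified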
w13_qweight_tmp.bitwise_or_(
    ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
    & (0xF << (4 * i))
)
w2_qweight_tmp.bitwise_or_(
    ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
    & (0xF << (4 * i))
)
The bitwise operations can be simplified for clarity and potentially better performance. The multiplication by a power of two can be replaced with a left bit shift. Also, the final bitwise AND is redundant since the shifted value is already a nibble.
Suggested change:
    w13_qweight_tmp.bitwise_or_(
        ((layer.w13_qweight.data >> shift_num) & 0xF) << (4 * i)
    )
    w2_qweight_tmp.bitwise_or_(
        ((layer.w2_qweight.data >> shift_num) & 0xF) << (4 * i)
    )

layer.register_parameter(
    "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
    "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
)
layer.register_parameter(
    "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
    "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
)
For consistency with other parts of the codebase (e.g., the GPU AWQ kernel) and to improve maintainability, it's better to use the replace_parameter utility function for replacing module parameters. You'll need to add from sglang.srt.layers.quantization.utils import replace_parameter to the imports, and then replace these layer.register_parameter calls with replace_parameter(layer, "param_name", new_tensor).
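A rough sketch of what that would look like here, assuming the replace_parameter(layer, name, tensor) signature quoted above and the layer / *_tmp names from the diff:

from sglang.srt.layers.quantization.utils import replace_parameter

# Same tensors as in the diff above, swapped in via the shared utility so the
# parameter replacement behaves consistently with the GPU AWQ kernel path.
replace_parameter(layer, "w13_qzeros", w13_qzeros_tmp)
replace_parameter(layer, "w13_qweight", w13_qweight_tmp)
replace_parameter(layer, "w2_qzeros", w2_qzeros_tmp)
replace_parameter(layer, "w2_qweight", w2_qweight_tmp)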
self.quant_config = quant_config
self.kernel = AWQMoEKernel(quant_config)
if self.quant_config.weight_bits != 4:
    raise ValueError("AWQMoEMethod only supports 4bit now.")
    top_k=topk_ids.shape[1],
    use_wna16=True,
)
return StandardCombineInput(hidden_states=output)
Since we are not gated behind specific kernel implementations, could you look into whether it is possible to call NPUW4A16Int4DynamicMoEMethod as a kernel here?
Example from one of our MoE refactoring PRs: https://github.com/sgl-project/sglang/pull/17361/changes#diff-34cc9aacc2ffaa0ad8351300aad66099bcbc2451d9a0a2c089aab5926d4f5e01
It should work for both apply and process_weights.
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    self.moe_runner_config = moe_runner_config
    self.kernel.moe_runner_config = moe_runner_config
Can we merge the awq_moe.py file and this file together, like it's done for the Linear schemes in awq_w4a16.py?
b8zhong left a comment
Hi, can you leave the GPU code outside of the hardware backend structure? I think that structure is designed for other hardware backends, if I'm not misunderstanding.
Hi! From our discussion with @rainj-me here, we thought it would be OK to move the kernel-related files from the quantization folder into the hardware_backend structure and create one for GPU.
@TamirBaydasov Thanks. I saw your comment, but I also saw #15194 (comment).
That comment was about moving all CUDA kernels into the hardware_backend structure, i.e. creating a structure similar to the NPU one, with more than just the quantization kernels living there.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
If you have a separate gpu folder, I think we should reduce or delete all the is_xxx code, right?
try:
    from sglang.jit_kernel.awq_dequantize import awq_dequantize
    from sglang.jit_kernel.awq_marlin_repack import (
        awq_marlin_moe_repack,
        awq_marlin_repack,
    )
    from sglang.srt.utils.custom_op import register_custom_op_from_extern

    awq_dequantize = register_custom_op_from_extern(
        awq_dequantize,
        fake_impl=lambda qweight, scales, qzeros: qweight.new_empty(
            qweight.shape[:-1] + (qweight.shape[-1] * 8,), dtype=scales.dtype
        ),
    )
except ImportError:
    try:
        from sglang.srt.layers.quantization.awq.awq_triton import (
            awq_dequantize_triton as awq_dequantize,
        )
    except ImportError:
        try:
            from sgl_kernel import awq_dequantize
        except ImportError:
            pass
I believe there is now a regression for XPU here, since it will go directly to Triton?
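One possible shape of the fix the author later describes (restoring the XPU path to sgl_kernel.awq_dequantize); this is a sketch only, where the boolean argument stands in for the is_xpu()-style helper used at the top of this file:

def _resolve_awq_dequantize(prefer_sgl_kernel_on_xpu: bool):
    """Sketch: a fallback order that keeps XPU on sgl_kernel.awq_dequantize
    instead of falling through to the Triton implementation."""
    try:
        from sglang.jit_kernel.awq_dequantize import awq_dequantize
        return awq_dequantize
    except ImportError:
        pass
    if prefer_sgl_kernel_on_xpu:
        # On XPU, try the sgl_kernel op before the Triton fallback.
        try:
            from sgl_kernel import awq_dequantize
            return awq_dequantize
        except ImportError:
            pass
    try:
        from sglang.srt.layers.quantization.awq.awq_triton import (
            awq_dequantize_triton as awq_dequantize,
        )
        return awq_dequantize
    except ImportError:
        try:
            from sgl_kernel import awq_dequantize
            return awq_dequantize
        except ImportError:
            return None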
class AWQAscendMoEScheme(AWQMoEScheme):
    def __init__(self, quant_config: "AWQConfig"):
        super().__init__(quant_config)
super().__init__(quant_config) constructs the GPU AWQMoEKernel, which transitively imports MarlinMoeQuantInfo and marlin_utils; none of that is needed on NPU.
Suggest skipping the parent init or refactoring to a _init_kernel() hook:
class AWQAscendMoEScheme(AWQMoEScheme):
    def __init__(self, quant_config: "AWQConfig"):
        # skip AWQMoEScheme.__init__
        from sglang.srt.hardware_backend.npu.quantization.awq_kernels import (
            AWQAscendMoEKernel,
        )

        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQAscendMoEScheme only supports 4bit now.")
        self.kernel = AWQAscendMoEKernel(quant_config)

Let's talk about this a bit further, as I think the cleaner long-term fix is making self.kernel come from a platform factory (see the plugin-integration note on awq.py), at which point this subclass disappears entirely.
class AWQLinearIntelAMXMethod(AWQLinearMethod):
    """Linear method for AWQ on Intel CPU with AMX."""

    def __init__(self, quant_config: "AWQConfig"):
        self.quant_config = quant_config
AWQIntelAMXLinearScheme overrides __init__ but doesn't call super().__init__() and doesn't set self.kernel. This follows the other comment, but let's clean it up or make it less brittle.
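To make that concrete, a minimal sketch of the _init_kernel() hook idea suggested earlier, applied to the AMX case; the kernel class names here are placeholders, not the PR's actual classes:

class _GPUAWQLinearKernel:
    """Placeholder for the in-tree GPU kernel object (name assumed)."""

    def __init__(self, quant_config):
        self.quant_config = quant_config


class _IntelAMXAWQLinearKernel:
    """Placeholder for a CPU/AMX kernel object (name assumed)."""

    def __init__(self, quant_config):
        self.quant_config = quant_config


class AWQLinearScheme:
    """Sketch of the base scheme: shared validation plus an overridable kernel hook."""

    def __init__(self, quant_config):
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQLinearScheme only supports 4bit now.")
        self.kernel = self._init_kernel(quant_config)

    def _init_kernel(self, quant_config):
        return _GPUAWQLinearKernel(quant_config)


class AWQIntelAMXLinearScheme(AWQLinearScheme):
    """The AMX scheme reuses the base __init__ and only swaps the kernel."""

    def _init_kernel(self, quant_config):
        return _IntelAMXAWQLinearKernel(quant_config)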
| "CompressedTensorsLinearMethod", | ||
| "AWQMarlinLinearMethod", | ||
| "AWQLinearMethod", | ||
| "AWQLinearAscendMethod", |
AWQLinearAscendMethod was deleted in this PR, so Ascend now goes through the unified AWQLinearMethod + AWQAscendLinearScheme. Since WEIGHT_LOADER_V2_SUPPORTED is matched by class-name string, we should remove this entry?
def get_linear_scheme(self, layer: torch.nn.Module):
    assert isinstance(layer, LinearBase)
    if _is_npu:
        return AWQAscendLinearScheme(self)
    return AWQLinearScheme(self)

def get_moe_scheme(self, layer: torch.nn.Module):
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

    assert isinstance(layer, FusedMoE)
    if _is_npu:
        return AWQAscendMoEScheme(self)
    raise NotImplementedError("AWQConfig only supports MoE scheme on NPU.")
Plugin-integration note (#21388 follow-up).
A TODO is needed here for the multiplatform plugin (even if you just mark it as needing integration with current_platform.is_out_of_tree).
something like:
def get_linear_scheme(self, layer: torch.nn.Module):
    assert isinstance(layer, LinearBase)
    from sglang.srt.platforms import current_platform

    cls = current_platform.get_awq_linear_scheme_cls()
    if cls is not None:
        return cls(self)
    return AWQLinearScheme(self)  # in-tree CUDA default

def get_moe_scheme(self, layer: torch.nn.Module):
    from sglang.srt.platforms import current_platform

    cls = current_platform.get_awq_moe_scheme_cls()
    if cls is None:
        raise NotImplementedError(
            f"AWQ MoE not provided by platform {current_platform.get_dispatch_key_name()!r}."
        )
    return cls(self)

With SRTPlatform extended to expose get_awq_linear_scheme_cls() / get_awq_moe_scheme_cls() / get_awq_marlin_linear_scheme_cls() (returning None by default; concrete platforms override). This matches how PR #21388 already exposes get_mha_kv_pool_cls(), get_graph_runner_cls(), etc.
Another option is to push the platform factory down to the .kernel layer (AWQLinearScheme.__init__ calls current_platform.get_awq_linear_kernel_cls()). This would eliminate AWQAscendLinearScheme / AWQIntelAMXLinearScheme as subclasses entirely, since they become kernel registrations on the OOT platform plugin. It also fixes the super().__init__ side-effect issues I mentioned earlier.
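A rough sketch of that alternative, assuming SRTPlatform grows a get_awq_linear_kernel_cls() accessor (hypothetical, mirroring the scheme-level accessors above) and using a placeholder default kernel class:

class _DefaultAWQLinearKernel:
    """Placeholder for the in-tree GPU kernel object (name assumed)."""

    def __init__(self, quant_config):
        self.quant_config = quant_config


class AWQLinearScheme:
    """Sketch: the scheme resolves its kernel through the platform, so
    out-of-tree backends register a kernel class instead of subclassing."""

    def __init__(self, quant_config):
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQLinearScheme only supports 4bit now.")

        from sglang.srt.platforms import current_platform

        # Hypothetical accessor; None means "use the in-tree default kernel".
        kernel_cls = current_platform.get_awq_linear_kernel_cls()
        self.kernel = (kernel_cls or _DefaultAWQLinearKernel)(quant_config)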
Thanks for the review. I addressed the AWQ scheme issues:
- Restored the XPU AWQ dequant path to use sgl_kernel.awq_dequantize instead of falling through to Triton.
- Refactored AWQ Linear/MoE schemes to use _init_kernel() hooks, so Ascend no longer initializes the default GPU/Marlin kernel before replacing it.
- Updated the CPU AMX AWQ path to use CPU-specific kernel objects behind the scheme, avoiding the brittle subclass-without-super().__init__() pattern.
- Removed the stale AWQLinearAscendMethod entry from WEIGHT_LOADER_V2_SUPPORTED.
- Added a TODO for moving AWQ scheme/kernel selection into the multiplatform plugin factory once quantization hooks are available.
        return AWQAscendLinearScheme(self)
    return AWQLinearScheme(self)

def get_moe_scheme(self, layer: torch.nn.Module):
nit: get_moe_scheme raising NotImplementedError for non-NPU is unreachable today (caller returns None for FusedMoE on non-NPU before consulting it). Either remove the raise or document the intent.
Please confirm whether this is a flaky UT or a code accuracy issue.
…/awq-scheme-refactor # Conflicts: # python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py
…c/awq-scheme-refactor
I merged it since several committers already reviewed it, and we confirmed that the single failing GPU CI job is unrelated to our change.
Motivation
Add schemes to AWQ instead of storing all classes in a single file, and split the kernel call from weight init. Follow-up to #17503.
Images and motivation for this PR can be viewed in our roadmap: #15194.
Modifications
Refactored AWQ to align with the scheme-based quantization structure used by modelslim and compressed_tensors.
Moved AWQ implementations out of the monolithic quantization/awq.py into the new package under quantization/awq/, with scheme implementations split into quantization/awq/schemes/.
Added get_linear_scheme and get_moe_scheme to awq/awq.py so linear and MoE paths select concrete schemes explicitly.
Unified AWQ quant methods into thin wrappers that delegate to layer.scheme, matching the compressed_tensors call pattern (see the sketch after this list).
Moved AWQ Triton helpers into quantization/awq/awq_triton.py and removed the old top-level quantization/awq_triton.py.
Split backend-specific kernel logic into separate GPU and NPU kernel modules.
This keeps awq.py focused on config, method dispatch, and scheme selection, while concrete weight handling and execution live in schemes and backend kernels.
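As a rough illustration of the thin-wrapper delegation mentioned above (only get_linear_scheme and the layer.scheme attribute come from this PR; the method signatures are abbreviated and the body is a sketch of the compressed_tensors-style pattern, not the exact implementation):

class AWQLinearMethod:
    """Sketch of the thin-wrapper pattern: the quant method only selects a
    scheme and forwards every call to it."""

    def __init__(self, quant_config):
        self.quant_config = quant_config

    def create_weights(self, layer, *weight_args, **extra_weight_attrs):
        # The config picks the concrete scheme for this layer (GPU, Ascend,
        # AMX, ...); all weight creation then lives inside the scheme.
        layer.scheme = self.quant_config.get_linear_scheme(layer)
        layer.scheme.create_weights(layer, *weight_args, **extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        layer.scheme.process_weights_after_loading(layer)

    def apply(self, layer, x, bias=None):
        return layer.scheme.apply(layer, x, bias=bias)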
Accuracy Tests
GPU tests:
NPU tests:
Benchmarking and Profiling
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci