-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Add AWQ quantization support for NPU. #10158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2ca9cb7
c74b35b
b009a54
b865bce
6857687
4515d25
e8cb970
281d084
a2c4e02
7c30f09
0744852
a32ed84
a3d9077
98f63d0
7edc183
44732a9
a2770f2
436e141
eda1ed7
e9530b2
1e38fe5
90ae315
21ce362
709f011
b289972
af799de
2c00cf3
31c6ec7
6d62983
599d004
385e194
7570fd2
168faac
71a8ec2
8fc5424
de5e2d2
0b24a48
0bdb55b
630f76d
ed55e68
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |||||||||||||||
| ) | ||||||||||||||||
| from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod | ||||||||||||||||
| from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter | ||||||||||||||||
| from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts | ||||||||||||||||
|
|
||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||
| from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig | ||||||||||||||||
|
|
@@ -39,11 +40,16 @@ | |||||||||||||||
| CombineInput, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| from sglang.srt.utils import is_cuda, is_hip, is_xpu | ||||||||||||||||
| from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu | ||||||||||||||||
|
|
||||||||||||||||
| _is_cuda = is_cuda() | ||||||||||||||||
| _is_hip = is_hip() | ||||||||||||||||
| _is_xpu = is_xpu() | ||||||||||||||||
| _is_npu = is_npu() | ||||||||||||||||
|
|
||||||||||||||||
| if _is_npu: | ||||||||||||||||
| import torch_npu | ||||||||||||||||
|
|
||||||||||||||||
| if _is_cuda: | ||||||||||||||||
| from sgl_kernel import ( | ||||||||||||||||
| awq_dequantize, | ||||||||||||||||
|
|
@@ -117,12 +123,17 @@ def get_name(self) -> str: | |||||||||||||||
| return "awq" | ||||||||||||||||
|
|
||||||||||||||||
| def get_supported_act_dtypes(self) -> List[torch.dtype]: | ||||||||||||||||
| return [torch.half] | ||||||||||||||||
| return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16] | ||||||||||||||||
|
|
||||||||||||||||
| @classmethod | ||||||||||||||||
| def get_min_capability(cls) -> int: | ||||||||||||||||
| # The AWQ kernel only supports Turing or newer GPUs. | ||||||||||||||||
| return 75 | ||||||||||||||||
| if _is_npu: | ||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||
| 'NPU hardware does not support "get_min_capability" feature.' | ||||||||||||||||
| ) | ||||||||||||||||
| else: | ||||||||||||||||
| return 75 | ||||||||||||||||
|
|
||||||||||||||||
| @staticmethod | ||||||||||||||||
| def get_config_filenames() -> List[str]: | ||||||||||||||||
|
|
@@ -146,6 +157,16 @@ def get_quant_method( | |||||||||||||||
| self, layer: torch.nn.Module, prefix: str | ||||||||||||||||
| ) -> Optional[LinearMethodBase]: | ||||||||||||||||
| from sglang.srt.layers.linear import LinearBase | ||||||||||||||||
| from sglang.srt.layers.moe.fused_moe_triton import FusedMoE | ||||||||||||||||
|
|
||||||||||||||||
| if _is_npu: | ||||||||||||||||
| if isinstance(layer, LinearBase): | ||||||||||||||||
| if is_layer_skipped_awq(prefix, self.modules_to_not_convert): | ||||||||||||||||
| return UnquantizedLinearMethod() | ||||||||||||||||
| return AWQLinearAscendMethod(self) | ||||||||||||||||
| elif isinstance(layer, FusedMoE): | ||||||||||||||||
| return AWQMoEAscendMethod(self) | ||||||||||||||||
| return None | ||||||||||||||||
|
|
||||||||||||||||
| if isinstance(layer, LinearBase): | ||||||||||||||||
| if is_layer_skipped_awq(prefix, self.modules_to_not_convert): | ||||||||||||||||
|
|
@@ -575,6 +596,64 @@ def apply( | |||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class AWQLinearAscendMethod(AWQLinearMethod): | ||||||||||||||||
| """Linear method for AWQ on Ascend. | ||||||||||||||||
|
|
||||||||||||||||
| Args: | ||||||||||||||||
| quant_config: The AWQ quantization config. | ||||||||||||||||
| """ | ||||||||||||||||
|
|
||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||||||||||||||||
| layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) | ||||||||||||||||
| qweight_tmp = torch.zeros_like(layer.qweight.data) | ||||||||||||||||
| qzeros_tmp = layer.qzeros.data | ||||||||||||||||
| qzeros_list = [] | ||||||||||||||||
| shifts = [0, 4, 1, 5, 2, 6, 3, 7] | ||||||||||||||||
|
|
||||||||||||||||
| for i in range(0, self.quant_config.pack_factor): | ||||||||||||||||
| shift_num = shifts[i] * 4 | ||||||||||||||||
| qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) | ||||||||||||||||
| qweight_tmp.bitwise_or_( | ||||||||||||||||
| ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i)) | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| qweight_tmp.bitwise_xor_(0x88888888) | ||||||||||||||||
|
|
||||||||||||||||
| qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1) | ||||||||||||||||
| qzeros_tmp = -(qzeros_tmp - 8) | ||||||||||||||||
| qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype) | ||||||||||||||||
|
|
||||||||||||||||
| layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False) | ||||||||||||||||
| layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False) | ||||||||||||||||
|
|
||||||||||||||||
| def apply( | ||||||||||||||||
| self, | ||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||
| bias: Optional[torch.Tensor] = None, | ||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||
| qweight = layer.qweight | ||||||||||||||||
| scales = layer.scales | ||||||||||||||||
| qzeros = layer.qzeros | ||||||||||||||||
| pack_factor = self.quant_config.pack_factor | ||||||||||||||||
| out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) | ||||||||||||||||
| reshaped_x = x.reshape(-1, x.shape[-1]) | ||||||||||||||||
|
|
||||||||||||||||
| if bias is not None and bias.dtype == torch.bfloat16: | ||||||||||||||||
| bias = bias.float() | ||||||||||||||||
|
|
||||||||||||||||
| out = torch_npu.npu_weight_quant_batchmatmul( | ||||||||||||||||
| reshaped_x, | ||||||||||||||||
| qweight, | ||||||||||||||||
| antiquant_scale=scales, | ||||||||||||||||
| antiquant_offset=qzeros, | ||||||||||||||||
| antiquant_group_size=self.quant_config.group_size, | ||||||||||||||||
| bias=bias, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| return out.reshape(out_shape) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class AWQMoEMethod(FusedMoEMethodBase): | ||||||||||||||||
|
|
||||||||||||||||
| def __init__(self, quant_config: AWQMarlinConfig): | ||||||||||||||||
|
|
@@ -677,7 +756,8 @@ def create_weights( | |||||||||||||||
| set_weight_attrs(w2_qzeros, extra_weight_attrs) | ||||||||||||||||
|
|
||||||||||||||||
| device = layer.w13_qweight.device | ||||||||||||||||
| layer.workspace = marlin_make_workspace(device, 4) | ||||||||||||||||
| if not _is_npu: | ||||||||||||||||
| layer.workspace = marlin_make_workspace(device, 4) | ||||||||||||||||
|
|
||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||||||||||||||||
| num_experts = layer.w13_qweight.shape[0] | ||||||||||||||||
|
|
@@ -785,3 +865,95 @@ def apply( | |||||||||||||||
| num_bits=self.quant_config.weight_bits, | ||||||||||||||||
| ).to(orig_dtype) | ||||||||||||||||
| return StandardCombineInput(hidden_states=output) | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| class AWQMoEAscendMethod(AWQMoEMethod): | ||||||||||||||||
| def __init__(self, quant_config: AWQConfig): | ||||||||||||||||
| self.quant_config = quant_config | ||||||||||||||||
|
Comment on lines
+871
to
+872
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The Given that
Suggested change
|
||||||||||||||||
|
|
||||||||||||||||
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||||||||||||||||
| w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data) | ||||||||||||||||
| w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data) | ||||||||||||||||
| w13_qzeros_list = [] | ||||||||||||||||
| w2_qzeros_list = [] | ||||||||||||||||
| shifts = [0, 4, 1, 5, 2, 6, 3, 7] | ||||||||||||||||
| for i in range(0, self.quant_config.pack_factor): | ||||||||||||||||
| shift_num = shifts[i] * 4 | ||||||||||||||||
| w13_qzeros_list.append( | ||||||||||||||||
| (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF | ||||||||||||||||
| ) | ||||||||||||||||
| w2_qzeros_list.append( | ||||||||||||||||
| (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF | ||||||||||||||||
| ) | ||||||||||||||||
| w13_qweight_tmp.bitwise_or_( | ||||||||||||||||
| ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i))) | ||||||||||||||||
| & (0xF << (4 * i)) | ||||||||||||||||
| ) | ||||||||||||||||
| w2_qweight_tmp.bitwise_or_( | ||||||||||||||||
| ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i))) | ||||||||||||||||
| & (0xF << (4 * i)) | ||||||||||||||||
| ) | ||||||||||||||||
|
Comment on lines
+888
to
+895
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to w13_qweight_tmp.bitwise_or_(
(((layer.w13_qweight.data >> shift_num) & 0xF) << (4 * i))
)
w2_qweight_tmp.bitwise_or_(
(((layer.w2_qweight.data >> shift_num) & 0xF) << (4 * i))
) |
||||||||||||||||
|
|
||||||||||||||||
| w13_qweight_tmp.bitwise_xor_(0x88888888) | ||||||||||||||||
| w2_qweight_tmp.bitwise_xor_(0x88888888) | ||||||||||||||||
|
|
||||||||||||||||
| w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape( | ||||||||||||||||
| layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1 | ||||||||||||||||
| ) | ||||||||||||||||
| w13_qzeros_tmp = -(w13_qzeros_tmp - 8) | ||||||||||||||||
| w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype) | ||||||||||||||||
| w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape( | ||||||||||||||||
| layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1 | ||||||||||||||||
| ) | ||||||||||||||||
| w2_qzeros_tmp = -(w2_qzeros_tmp - 8) | ||||||||||||||||
| w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype) | ||||||||||||||||
|
|
||||||||||||||||
| layer.register_parameter( | ||||||||||||||||
| "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False) | ||||||||||||||||
| ) | ||||||||||||||||
| layer.register_parameter( | ||||||||||||||||
| "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False) | ||||||||||||||||
| ) | ||||||||||||||||
| layer.register_parameter( | ||||||||||||||||
| "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False) | ||||||||||||||||
| ) | ||||||||||||||||
| layer.register_parameter( | ||||||||||||||||
| "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False) | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| def create_moe_runner( | ||||||||||||||||
| self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig | ||||||||||||||||
| ): | ||||||||||||||||
| self.moe_runner_config = moe_runner_config | ||||||||||||||||
|
|
||||||||||||||||
| def apply( | ||||||||||||||||
| self, | ||||||||||||||||
| layer: torch.nn.Module, | ||||||||||||||||
| dispatch_output: StandardDispatchOutput, | ||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||
| from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput | ||||||||||||||||
|
|
||||||||||||||||
| assert ( | ||||||||||||||||
| self.moe_runner_config.activation == "silu" | ||||||||||||||||
| ), "Only SiLU activation is supported." | ||||||||||||||||
|
|
||||||||||||||||
| x = dispatch_output.hidden_states | ||||||||||||||||
| topk_output = dispatch_output.topk_output | ||||||||||||||||
|
|
||||||||||||||||
| topk_weights, topk_ids, _ = topk_output | ||||||||||||||||
| topk_ids = topk_ids.to(torch.int32) | ||||||||||||||||
| topk_weights = topk_weights.to(x.dtype) | ||||||||||||||||
| output = npu_fused_experts( | ||||||||||||||||
| hidden_states=x, | ||||||||||||||||
| w13=layer.w13_qweight, | ||||||||||||||||
| w13_scale=layer.w13_scales, | ||||||||||||||||
| w13_offset=layer.w13_qzeros, | ||||||||||||||||
| w2=layer.w2_qweight, | ||||||||||||||||
| w2_scale=layer.w2_scales, | ||||||||||||||||
| w2_offset=layer.w2_qzeros, | ||||||||||||||||
| topk_weights=topk_weights, | ||||||||||||||||
| topk_ids=topk_ids, | ||||||||||||||||
| top_k=topk_ids.shape[1], | ||||||||||||||||
| use_wna16=True, | ||||||||||||||||
| ) | ||||||||||||||||
| return StandardCombineInput(hidden_states=output) | ||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bitwise operation used for repacking weights is functionally correct but unnecessarily complex and hard to read. Using
(2 ** (4 * i))for left-shifting and then masking can be simplified. A more direct and readable approach is to first mask the desired nibble with& 0xFand then shift it to its new position. This improves code clarity and maintainability.