From 2ca9cb7ff67a822fc9b57b000488d382f2f86248 Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Mon, 8 Sep 2025 07:57:06 +0000 Subject: [PATCH 01/16] add awq quantization to ascend backend --- python/sglang/srt/layers/linear.py | 1 + python/sglang/srt/layers/quantization/awq.py | 227 +++++++++++++++++- .../srt/layers/quantization/awq_triton.py | 19 ++ python/sglang/srt/model_loader/loader.py | 2 + python/sglang/srt/models/deepseek_v2.py | 4 +- python/sglang/srt/utils.py | 2 + 6 files changed, 250 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 47dfc7324fc0..f783d6b138d7 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -45,6 +45,7 @@ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", + "AWQLinearAscendMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "BlockInt8LinearMethod", diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 9cba60c2b532..e3b082ef3b69 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -39,10 +39,15 @@ CombineInput, ) -from sglang.srt.utils import is_cuda, is_hip +from sglang.srt.utils import is_cuda, is_hip, is_npu _is_cuda = is_cuda() _is_hip = is_hip() +_is_npu = is_npu() + +if _is_npu: + import torch_npu + if _is_cuda: from sgl_kernel import ( awq_dequantize, @@ -112,12 +117,21 @@ def get_name(self) -> str: return "awq" def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half] + return ( + [torch.float16] + if not _is_npu + else [torch.float16, torch.bfloat16] + ) @classmethod def get_min_capability(cls) -> int: # The AWQ kernel only supports Turing or newer GPUs. - return 75 + if _is_npu: + raise NotImplementedError( + 'NPU hardware does not support "get_min_capability" feature.' + ) + else: + return 75 @staticmethod def get_config_filenames() -> List[str]: @@ -141,7 +155,17 @@ def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[LinearMethodBase]: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + if _is_npu: + if isinstance(layer, LinearBase): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return AWQLinearAscendMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEAscendMethod(self) + return None + if isinstance(layer, LinearBase): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() @@ -569,7 +593,61 @@ def apply( bias=bias, ) +class AWQLinearAscendMethod(AWQLinearMethod): + """Linear method for AWQ on Ascend. + + Args: + quant_config: The AWQ quantization config. + """ + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) + qweight_tmp = torch.zeros_like(layer.qweight.data) + qzeros_tmp = layer.qzeros.data + qzeros_list = [] + shifts = [0, 4, 1, 5, 2, 6, 3, 7] + + for i in range(0,self.quant_config.pack_factor): + shift_num = shifts[i] * 4 + qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) + qweight_tmp.bitwise_or_(((layer.qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) + + qweight_tmp.bitwise_xor_(0x88888888) + + qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1) + qzeros_tmp = -(qzeros_tmp - 8) + qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype) + layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False) + layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) + reshaped_x = x.reshape(-1, x.shape[-1]) + + if bias is not None and bias.dtype == torch.bfloat16: + bias = bias.float() + + out = torch_npu.npu_weight_quant_batchmatmul( + reshaped_x, + qweight, + antiquant_scale=scales, + antiquant_offset=qzeros, + antiquant_group_size=self.quant_config.group_size, + bias=bias, + ) + + return out.reshape(out_shape) + class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQMarlinConfig): @@ -672,7 +750,8 @@ def create_weights( set_weight_attrs(w2_qzeros, extra_weight_attrs) device = layer.w13_qweight.device - layer.workspace = marlin_make_workspace(device, 4) + if not _is_npu: + layer.workspace = marlin_make_workspace(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] @@ -780,3 +859,143 @@ def apply( num_bits=self.quant_config.weight_bits, ).to(orig_dtype) return StandardCombineInput(hidden_states=output) + +def npu_fused_experts( + hidden_states: torch.Tensor, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_offset: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_offset: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, +): + original_shape = hidden_states.shape + original_dtype = hidden_states.dtype + scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32 + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + num_tokens = hidden_states.shape[0] + num_experts = w13.shape[0] + row_idx_len = num_tokens * top_k + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens + ) + ) + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, num_experts + ) + expert_tokens = expert_tokens.to(torch.int64) + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w13], + antiquant_scale=[w13_scale], + antiquant_offset=[w13_offset], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + antiquant_scale=[w2_scale], + antiquant_offset=[w2_offset], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens, + output_dtype=original_dtype, + )[0] + + final_hidden_states = torch_npu.npu_moe_finalize_routing( + hidden_states, + skip1=None, + skip2=None, + bias=None, + scales=topk_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids, + ) + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + return final_hidden_states + +class AWQMoEAscendMethod(AWQMoEMethod): + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data) + w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data) + w13_qzeros_list = [] + w2_qzeros_list = [] + shifts = [0, 4, 1, 5, 2, 6, 3, 7] + for i in range(0,self.quant_config.pack_factor): + shift_num = shifts[i] * 4 + w13_qzeros_list.append((layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) + w2_qzeros_list.append((layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) + w13_qweight_tmp.bitwise_or_(((layer.w13_qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) + w2_qweight_tmp.bitwise_or_(((layer.w2_qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) + + w13_qweight_tmp.bitwise_xor_(0x88888888) + w2_qweight_tmp.bitwise_xor_(0x88888888) + + w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1) + w13_qzeros_tmp = -(w13_qzeros_tmp - 8) + w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype) + w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1) + w2_qzeros_tmp = -(w2_qzeros_tmp - 8) + w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype) + + layer.register_parameter("w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)) + layer.register_parameter("w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)) + layer.register_parameter("w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)) + layer.register_parameter("w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> torch.Tensor: + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + topk_weights, topk_ids, _ = topk_output + topk_ids = topk_ids.to(torch.int32) + topk_weights = topk_weights.to(x.dtype) + return npu_fused_experts( + hidden_states=x, + w13=layer.w13_qweight, + w13_scale=layer.w13_scales, + w13_offset=layer.w13_qzeros, + w2=layer.w2_qweight, + w2_scale=layer.w2_scales, + w2_offset=layer.w2_qzeros, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=topk_ids.shape[1], + ) \ No newline at end of file diff --git a/python/sglang/srt/layers/quantization/awq_triton.py b/python/sglang/srt/layers/quantization/awq_triton.py index 13352efdb650..1c420a05b281 100644 --- a/python/sglang/srt/layers/quantization/awq_triton.py +++ b/python/sglang/srt/layers/quantization/awq_triton.py @@ -337,3 +337,22 @@ def awq_gemm_triton( result = result.sum(0) return result + +def awq_dequantize_decomposition( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, +) -> torch.Tensor: + qweight_tmp = qweight + qzeros_tmp = zeros + qweight_list = [] + qzeros_list = [] + shifts = [0, 4, 1, 5, 2, 6, 3, 7] + for i in range(0, 8): + shift_num = shifts[i] * 4 + qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) + qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF) + qzeros_tmp = torch.cat(qzeros_list,dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype) + qweight_tmp = torch.cat(qweight_list,dim=-1).reshape(qweight_tmp.shape[0], -1).to(scales.dtype) + res = (qweight_tmp.reshape(qzeros_tmp.shape[0],-1,qzeros_tmp.shape[1]) - qzeros_tmp.unsqueeze(1)) * scales.unsqueeze(1) + return res.reshape(qweight_tmp.shape[0],-1) \ No newline at end of file diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index d2b4c6bfcc73..eec3e4d0b4da 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -499,6 +499,8 @@ def load_weights_and_postprocess(model, weights, target_device): # parameters onto device for processing and back off after. with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) + if is_npu: + torch.npu.empty_cache() class LayeredModelLoader(DefaultModelLoader): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 06ebf7f785d2..587c01be9f52 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -164,6 +164,8 @@ from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_triton as awq_dequantize, ) +elif _is_npu: + from sglang.srt.layers.quantization.awq_triton import awq_dequantize_decomposition as awq_dequantize else: from vllm._custom_ops import awq_dequantize @@ -2481,7 +2483,7 @@ def post_load_weights(self, is_nextn=False, weight_names=None): ) if hasattr(self_attn.kv_b_proj, "qweight"): # AWQ compatible - if _is_cuda or _is_hip: + if _is_cuda or _is_hip or _is_npu: w = awq_dequantize( self_attn.kv_b_proj.qweight, self_attn.kv_b_proj.scales, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 7ea3f36d5b3b..5eadae41afd6 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -444,6 +444,8 @@ def get_available_gpu_memory( f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ", "which may cause useless memory allocation for torch NPU context.", ) + if empty_cache: + torch.npu.empty_cache() free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info() if distributed: From c74b35bdd737f3226181ecd2914c284090159961 Mon Sep 17 00:00:00 2001 From: ErvinXie Date: Mon, 8 Sep 2025 08:17:59 +0000 Subject: [PATCH 02/16] format --- python/sglang/srt/layers/quantization/awq.py | 76 ++++++++++++------- .../srt/layers/quantization/awq_triton.py | 18 ++++- python/sglang/srt/models/deepseek_v2.py | 4 +- 3 files changed, 67 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index e3b082ef3b69..8cc7be50f899 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -117,11 +117,7 @@ def get_name(self) -> str: return "awq" def get_supported_act_dtypes(self) -> List[torch.dtype]: - return ( - [torch.float16] - if not _is_npu - else [torch.float16, torch.bfloat16] - ) + return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: @@ -165,7 +161,7 @@ def get_quant_method( elif isinstance(layer, FusedMoE): return AWQMoEAscendMethod(self) return None - + if isinstance(layer, LinearBase): if is_layer_skipped_awq(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() @@ -593,6 +589,7 @@ def apply( bias=bias, ) + class AWQLinearAscendMethod(AWQLinearMethod): """Linear method for AWQ on Ascend. @@ -607,11 +604,13 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qzeros_list = [] shifts = [0, 4, 1, 5, 2, 6, 3, 7] - for i in range(0,self.quant_config.pack_factor): + for i in range(0, self.quant_config.pack_factor): shift_num = shifts[i] * 4 qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) - qweight_tmp.bitwise_or_(((layer.qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) - + qweight_tmp.bitwise_or_( + ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i)) + ) + qweight_tmp.bitwise_xor_(0x88888888) qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1) @@ -636,7 +635,7 @@ def apply( if bias is not None and bias.dtype == torch.bfloat16: bias = bias.float() - + out = torch_npu.npu_weight_quant_batchmatmul( reshaped_x, qweight, @@ -647,7 +646,8 @@ def apply( ) return out.reshape(out_shape) - + + class AWQMoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: AWQMarlinConfig): @@ -860,6 +860,7 @@ def apply( ).to(orig_dtype) return StandardCombineInput(hidden_states=output) + def npu_fused_experts( hidden_states: torch.Tensor, w13: torch.Tensor, @@ -935,6 +936,7 @@ def npu_fused_experts( final_hidden_states = final_hidden_states.view(original_shape) return final_hidden_states + class AWQMoEAscendMethod(AWQMoEMethod): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config @@ -945,32 +947,54 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_qzeros_list = [] w2_qzeros_list = [] shifts = [0, 4, 1, 5, 2, 6, 3, 7] - for i in range(0,self.quant_config.pack_factor): + for i in range(0, self.quant_config.pack_factor): shift_num = shifts[i] * 4 - w13_qzeros_list.append((layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) - w2_qzeros_list.append((layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF) - w13_qweight_tmp.bitwise_or_(((layer.w13_qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) - w2_qweight_tmp.bitwise_or_(((layer.w2_qweight.data >> shift_num) * (2 ** (4*i))) & (0xF << (4*i))) - + w13_qzeros_list.append( + (layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF + ) + w2_qzeros_list.append( + (layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF + ) + w13_qweight_tmp.bitwise_or_( + ((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i))) + & (0xF << (4 * i)) + ) + w2_qweight_tmp.bitwise_or_( + ((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i))) + & (0xF << (4 * i)) + ) + w13_qweight_tmp.bitwise_xor_(0x88888888) w2_qweight_tmp.bitwise_xor_(0x88888888) - w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1) + w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape( + layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1 + ) w13_qzeros_tmp = -(w13_qzeros_tmp - 8) w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype) - w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1) + w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape( + layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1 + ) w2_qzeros_tmp = -(w2_qzeros_tmp - 8) w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype) - layer.register_parameter("w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)) - layer.register_parameter("w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)) - layer.register_parameter("w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)) - layer.register_parameter("w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)) + layer.register_parameter( + "w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False) + ) + layer.register_parameter( + "w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False) + ) + layer.register_parameter( + "w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False) + ) + layer.register_parameter( + "w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False) + ) def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): - self.moe_runner_config = moe_runner_config + self.moe_runner_config = moe_runner_config def apply( self, @@ -980,7 +1004,7 @@ def apply( assert ( self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." - + x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output @@ -998,4 +1022,4 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, top_k=topk_ids.shape[1], - ) \ No newline at end of file + ) diff --git a/python/sglang/srt/layers/quantization/awq_triton.py b/python/sglang/srt/layers/quantization/awq_triton.py index 1c420a05b281..b83dd79fbcc3 100644 --- a/python/sglang/srt/layers/quantization/awq_triton.py +++ b/python/sglang/srt/layers/quantization/awq_triton.py @@ -338,6 +338,7 @@ def awq_gemm_triton( return result + def awq_dequantize_decomposition( qweight: torch.Tensor, scales: torch.Tensor, @@ -352,7 +353,16 @@ def awq_dequantize_decomposition( shift_num = shifts[i] * 4 qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF) qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF) - qzeros_tmp = torch.cat(qzeros_list,dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype) - qweight_tmp = torch.cat(qweight_list,dim=-1).reshape(qweight_tmp.shape[0], -1).to(scales.dtype) - res = (qweight_tmp.reshape(qzeros_tmp.shape[0],-1,qzeros_tmp.shape[1]) - qzeros_tmp.unsqueeze(1)) * scales.unsqueeze(1) - return res.reshape(qweight_tmp.shape[0],-1) \ No newline at end of file + qzeros_tmp = ( + torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype) + ) + qweight_tmp = ( + torch.cat(qweight_list, dim=-1) + .reshape(qweight_tmp.shape[0], -1) + .to(scales.dtype) + ) + res = ( + qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1]) + - qzeros_tmp.unsqueeze(1) + ) * scales.unsqueeze(1) + return res.reshape(qweight_tmp.shape[0], -1) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 587c01be9f52..7f30248d6e4a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -165,7 +165,9 @@ awq_dequantize_triton as awq_dequantize, ) elif _is_npu: - from sglang.srt.layers.quantization.awq_triton import awq_dequantize_decomposition as awq_dequantize + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_decomposition as awq_dequantize, + ) else: from vllm._custom_ops import awq_dequantize From b009a54619a6f0cbfde5c588269038eae0e7fc96 Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Mon, 8 Sep 2025 10:06:51 +0000 Subject: [PATCH 03/16] code fix --- python/sglang/srt/layers/quantization/awq.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index e3b082ef3b69..2b4d3f2fd625 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -977,6 +977,8 @@ def apply( layer: torch.nn.Module, dispatch_output: StandardDispatchOutput, ) -> torch.Tensor: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + assert ( self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." @@ -987,7 +989,7 @@ def apply( topk_weights, topk_ids, _ = topk_output topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) - return npu_fused_experts( + output = npu_fused_experts( hidden_states=x, w13=layer.w13_qweight, w13_scale=layer.w13_scales, @@ -998,4 +1000,5 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, top_k=topk_ids.shape[1], - ) \ No newline at end of file + ) + return StandardCombineInput(hidden_states=output) \ No newline at end of file From 4515d259447bfd403c396a19e5d5d5bb399cdf2c Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Tue, 16 Sep 2025 07:19:01 +0000 Subject: [PATCH 04/16] format --- python/sglang/srt/layers/quantization/awq.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 3095472ba2ad..193b135275f6 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -1002,7 +1002,7 @@ def apply( dispatch_output: StandardDispatchOutput, ) -> torch.Tensor: from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - + assert ( self.moe_runner_config.activation == "silu" ), "Only SiLU activation is supported." @@ -1025,4 +1025,4 @@ def apply( topk_ids=topk_ids, top_k=topk_ids.shape[1], ) - return StandardCombineInput(hidden_states=output) \ No newline at end of file + return StandardCombineInput(hidden_states=output) From 281d0847a615b20610980ab21f5480e2dfcf8f60 Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Wed, 17 Sep 2025 07:53:57 +0000 Subject: [PATCH 05/16] refact npu_fused_experts --- python/sglang/srt/layers/quantization/awq.py | 77 +------------------ .../srt/layers/quantization/w8a8_int8.py | 32 ++++++-- 2 files changed, 28 insertions(+), 81 deletions(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 193b135275f6..84dad32c54d5 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -7,6 +7,7 @@ import torch +from python.sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts from sglang.srt.layers.linear import LinearBase, set_weight_attrs from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -861,82 +862,6 @@ def apply( return StandardCombineInput(hidden_states=output) -def npu_fused_experts( - hidden_states: torch.Tensor, - w13: torch.Tensor, - w13_scale: torch.Tensor, - w13_offset: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - w2_offset: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, -): - original_shape = hidden_states.shape - original_dtype = hidden_states.dtype - scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32 - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - num_tokens = hidden_states.shape[0] - num_experts = w13.shape[0] - row_idx_len = num_tokens * top_k - row_idx = ( - torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) - .view(top_k, -1) - .permute(1, 0) - .contiguous() - ) - hidden_states, expanded_row_idx, expanded_expert_idx = ( - torch_npu.npu_moe_init_routing( - hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens - ) - ) - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts - ) - expert_tokens = expert_tokens.to(torch.int64) - # gmm1: gate_up_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w13], - antiquant_scale=[w13_scale], - antiquant_offset=[w13_offset], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - output_dtype=original_dtype, - )[0] - # act_fn: swiglu - hidden_states = torch_npu.npu_swiglu(hidden_states) - # gmm2: down_proj - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - antiquant_scale=[w2_scale], - antiquant_offset=[w2_offset], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - output_dtype=original_dtype, - )[0] - - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - class AWQMoEAscendMethod(AWQMoEMethod): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 5ccb0259da31..bedaf8d53e07 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -118,7 +118,15 @@ def npu_fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, top_k: int, + **kwargs, ): + w13_offset = kwargs.get("w13_offset", None) + w2_offset = kwargs.get("w2_offset", None) + if w13_offset is not None or w2_offset is not None: + use_awq = True + else: + use_awq = False + original_shape = hidden_states.shape original_dtype = hidden_states.dtype scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32 @@ -143,12 +151,27 @@ def npu_fused_experts( ) expert_tokens = expert_tokens.to(torch.int64) # gmm1: gate_up_proj - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + if not use_awq: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + scale_args13 = { + "scale": [w13_scale.to(scale_dtype)], + "per_token_scale": [pertoken_scale], + } + scale_args2 = { + "scale": [w2_scale.to(scale_dtype)], + "per_token_scale": [pertoken_scale], + } + else: + scale_args13 = { + "antiquant_scale": [w13_scale], + "antiquant_offset": [w13_offset], + } + scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]} + hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w13], - scale=[w13_scale.to(scale_dtype)], - per_token_scale=[pertoken_scale], + **scale_args13, split_item=2, group_list_type=0, group_type=0, @@ -162,8 +185,7 @@ def npu_fused_experts( hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], weight=[w2], - scale=[w2_scale.to(scale_dtype)], - per_token_scale=[pertoken_scale], + **scale_args2, split_item=2, group_list_type=0, group_type=0, From 7c30f098bc9ce36d9b971399e5b2d9247736acff Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Wed, 17 Sep 2025 08:11:46 +0000 Subject: [PATCH 06/16] minor --- python/sglang/srt/layers/quantization/w8a8_int8.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index bedaf8d53e07..8c9384e335ec 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -122,10 +122,7 @@ def npu_fused_experts( ): w13_offset = kwargs.get("w13_offset", None) w2_offset = kwargs.get("w2_offset", None) - if w13_offset is not None or w2_offset is not None: - use_awq = True - else: - use_awq = False + use_wna16 = kwargs.get("use_wna16", False) original_shape = hidden_states.shape original_dtype = hidden_states.dtype @@ -151,7 +148,7 @@ def npu_fused_experts( ) expert_tokens = expert_tokens.to(torch.int64) # gmm1: gate_up_proj - if not use_awq: + if not use_wna16: hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) scale_args13 = { "scale": [w13_scale.to(scale_dtype)], From 074485287f4062ce9384be3b59aacb3b2c4027be Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Wed, 17 Sep 2025 08:20:55 +0000 Subject: [PATCH 07/16] minor --- python/sglang/srt/layers/quantization/awq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 84dad32c54d5..53740da72e18 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -949,5 +949,6 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, top_k=topk_ids.shape[1], + use_wna16=True, ) return StandardCombineInput(hidden_states=output) From a32ed84291e8ed6cc32f6e16e334eb74030c297d Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Wed, 17 Sep 2025 09:14:36 +0000 Subject: [PATCH 08/16] minor --- python/sglang/srt/layers/quantization/awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 53740da72e18..3b0fabb89594 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -7,7 +7,6 @@ import torch -from python.sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts from sglang.srt.layers.linear import LinearBase, set_weight_attrs from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -32,6 +31,7 @@ ) from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter +from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts if TYPE_CHECKING: from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig From a3d907789f431f0183dd6f55f26b78f74606b4b6 Mon Sep 17 00:00:00 2001 From: xwy-sap4 Date: Wed, 17 Sep 2025 09:33:32 +0000 Subject: [PATCH 09/16] minor --- .../sglang/srt/layers/quantization/w8a8_int8.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 8c9384e335ec..9bcb1e5c899f 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -154,16 +154,11 @@ def npu_fused_experts( "scale": [w13_scale.to(scale_dtype)], "per_token_scale": [pertoken_scale], } - scale_args2 = { - "scale": [w2_scale.to(scale_dtype)], - "per_token_scale": [pertoken_scale], - } else: scale_args13 = { "antiquant_scale": [w13_scale], "antiquant_offset": [w13_offset], } - scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]} hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -177,7 +172,15 @@ def npu_fused_experts( )[0] # act_fn: swiglu hidden_states = torch_npu.npu_swiglu(hidden_states) - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + if not use_wna16: + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + + scale_args2 = { + "scale": [w2_scale.to(scale_dtype)], + "per_token_scale": [pertoken_scale], + } + else: + scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]} # gmm2: down_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], From 44732a99fb2b6b929e846f8ba8fbaa5aa832f7b3 Mon Sep 17 00:00:00 2001 From: Zhengda Qin Date: Sun, 28 Sep 2025 14:48:51 +0800 Subject: [PATCH 10/16] ci bug fix --- python/sglang/srt/model_loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index da78345bc0ce..ab2fec086632 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -511,7 +511,7 @@ def load_weights_and_postprocess(model, weights, target_device): # parameters onto device for processing and back off after. with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) - if is_npu: + if _is_npu: torch.npu.empty_cache() From a2770f2113fb097155b14fc5b9da0bdd0ad87dfd Mon Sep 17 00:00:00 2001 From: Zhengda Qin Date: Sun, 28 Sep 2025 18:12:42 +0800 Subject: [PATCH 11/16] bug fix --- python/sglang/srt/layers/rotary_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index fd69ea72776e..d0ff9f764213 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -176,6 +176,7 @@ def forward_npu( query: torch.Tensor, key: torch.Tensor, offsets: Optional[torch.Tensor] = None, + fused_set_kv_buffer_arg=None, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-npu implementation of forward().""" import os From af799de968017e85ec0d1e371426f5966765d7ad Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Sat, 11 Oct 2025 16:32:00 +0800 Subject: [PATCH 12/16] ci fix --- python/sglang/srt/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3e9180fbf384..e71e8979159f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -186,6 +186,7 @@ import custom_ops import sgl_kernel_npu import torch_npu + from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_decomposition as awq_dequantize, ) From 31c6ec746f7e6224902f1f912e347b8195a6ff6a Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Sat, 11 Oct 2025 16:43:31 +0800 Subject: [PATCH 13/16] chore: apply pre-commit autofix (trailing whitespace) --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e71e8979159f..7ee2ae4903a4 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -186,7 +186,7 @@ import custom_ops import sgl_kernel_npu import torch_npu - + from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_decomposition as awq_dequantize, ) From 7570fd2ff6a89e4556bb5e9710a59eb2687a7ae1 Mon Sep 17 00:00:00 2001 From: ErvinXie Date: Sat, 18 Oct 2025 12:00:11 +0800 Subject: [PATCH 14/16] format --- python/sglang/srt/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 1df8dfba2b9d..3b0a3d404938 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -185,6 +185,7 @@ import custom_ops # noqa: F401 import sgl_kernel_npu # noqa: F401 import torch_npu # noqa: F401 + from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_decomposition as awq_dequantize, ) From 168faacf74ed093378ecf246aa2ff7ee4c18d038 Mon Sep 17 00:00:00 2001 From: "xwy@sap2" Date: Sat, 18 Oct 2025 04:23:22 +0000 Subject: [PATCH 15/16] format --- python/sglang/srt/models/deepseek_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 3b0a3d404938..b0c8f13713ff 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -185,7 +185,7 @@ import custom_ops # noqa: F401 import sgl_kernel_npu # noqa: F401 import torch_npu # noqa: F401 - + from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_decomposition as awq_dequantize, ) From 0bdb55b2e20514b7ec308ed5658a1cacb664e738 Mon Sep 17 00:00:00 2001 From: Alisehen <814073252@qq.com> Date: Tue, 21 Oct 2025 11:55:19 +0800 Subject: [PATCH 16/16] format fix --- python/sglang/srt/layers/quantization/awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 51fb45d97ed9..5b4a7536793f 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -40,7 +40,7 @@ CombineInput, ) -from sglang.srt.utils import is_cuda, is_hip, is_xpu, is_npu +from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu _is_cuda = is_cuda() _is_hip = is_hip()