diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 22d9021f286c..2c804058a16f 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -297,6 +297,8 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): if ( self.use_mla and forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_draft_extend(include_v2=True) + and not forward_batch.forward_mode.is_target_verify() and sum(forward_batch.extend_prefix_lens_cpu) > 0 ): self.forward_metadata.prefix_lens = forward_batch.extend_prefix_lens.to( diff --git a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py index 938314b0f425..279700b3e542 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py @@ -567,11 +567,13 @@ def apply_without_routing_weights( group_list, output_dtype, ): + from sgl_kernel_npu.activation.swiglu_quant import swiglu_quant + hidden_states = torch.ops.npu.npu_grouped_matmul( x=[hidden_states], - weight=[self.w13_weight], - scale=[self.w13_weight_scale], - bias=[self.w13_scale_bias], + weight=[layer.w13_weight], + scale=[layer.w13_weight_scale], + bias=[layer.w13_scale_bias], per_token_scale=[hidden_states_scale], group_list=group_list, split_item=2, @@ -580,15 +582,15 @@ def apply_without_routing_weights( output_dtype=output_dtype, )[0] - # act_fn: swiglu - hidden_states = torch.ops.npu.npu_swiglu(hidden_states) - hidden_states, swiglu_out_scale = torch.ops.npu.npu_dynamic_quant(hidden_states) + hidden_states, swiglu_out_scale = swiglu_quant( + hidden_states, group_list, group_list_type + ) hidden_states = torch.ops.npu.npu_grouped_matmul( x=[hidden_states], - weight=[self.w2_weight], - scale=[self.w2_weight_scale], - bias=[self.w2_scale_bias], + weight=[layer.w2_weight], + scale=[layer.w2_weight_scale], + bias=[layer.w2_scale_bias], per_token_scale=[swiglu_out_scale], group_list=group_list, split_item=2, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 97e65433144d..53c832e1d1bd 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -6,6 +6,9 @@ import torch from sglang.srt.environ import envs +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW4A16Int4DynamicMoEMethod, +) from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe import ( get_deepep_mode, @@ -347,7 +350,9 @@ def forward_npu( ) else: input_quant = get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT") - if not input_quant and self.w13_weight.dtype != torch.int32: + if not input_quant and not isinstance( + self.quant_method, NPUW4A16Int4DynamicMoEMethod + ): hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant( hidden_states )