diff --git a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py index 470617585c97..149b4a939249 100644 --- a/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py +++ b/python/sglang/srt/hardware_backend/npu/modules/deepseek_v2_attention_mla_npu.py @@ -229,19 +229,17 @@ def forward_mla_core_npu( k_rope=k_pe, **(dict(topk_indices=topk_indices) if topk_indices is not None else {}), ) - attn_output = attn_output.view(-1, m.num_local_heads, m.kv_lora_rank) attn_bmm_output = torch.empty( - (attn_output.shape[0], m.num_local_heads * m.v_head_dim), + (attn_output.shape[0], m.num_local_heads, m.v_head_dim), dtype=attn_output.dtype, device=attn_output.device, ) - torch.bmm( - attn_output.transpose(0, 1), - m.w_vc, - out=attn_bmm_output.view(-1, m.num_local_heads, m.v_head_dim).transpose(0, 1), - ) + torch.ops.npu.batch_matmul_transpose(attn_output, m.w_vc, attn_bmm_output) + + attn_bmm_output = attn_bmm_output.reshape(-1, m.num_local_heads * m.v_head_dim) + output, _ = m.o_proj(attn_bmm_output) return output @@ -358,7 +356,11 @@ def forward_dsa_core_npu( device=attn_output.device, ) - if not forward_batch.forward_mode.is_decode(): + if ( + not forward_batch.forward_mode.is_decode() + and not forward_batch.forward_mode.is_draft_extend(include_v2=True) + and not forward_batch.forward_mode.is_target_verify() + ): attn_output = attn_output.transpose(0, 1) torch.bmm( attn_output, diff --git a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py index 938314b0f425..275a0bb80314 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py @@ -569,9 +569,9 @@ def apply_without_routing_weights( ): hidden_states = torch.ops.npu.npu_grouped_matmul( x=[hidden_states], - weight=[self.w13_weight], - scale=[self.w13_weight_scale], - bias=[self.w13_scale_bias], + weight=[layer.w13_weight], + scale=[layer.w13_weight_scale], + bias=[layer.w13_scale_bias], per_token_scale=[hidden_states_scale], group_list=group_list, split_item=2, @@ -586,9 +586,9 @@ def apply_without_routing_weights( hidden_states = torch.ops.npu.npu_grouped_matmul( x=[hidden_states], - weight=[self.w2_weight], - scale=[self.w2_weight_scale], - bias=[self.w2_scale_bias], + weight=[layer.w2_weight], + scale=[layer.w2_weight_scale], + bias=[layer.w2_scale_bias], per_token_scale=[swiglu_out_scale], group_list=group_list, split_item=2,