diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index 6533699b0607..7c481e0f363c 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -8,6 +8,7 @@ from sgl_kernel_npu.mamba.causal_conv1d import ( causal_conv1d_fn_npu, causal_conv1d_update_npu, + causal_conv1d_update_v2, ) from sglang.srt.hardware_backend.npu.attention.ascend_hybrid_linear_attn_backend import ( @@ -224,9 +225,7 @@ def forward_extend( else: has_initial_states = forward_batch.extend_prefix_lens > 0 if is_target_verify: - draft_token_num = forward_batch.spec_info.draft_token_num num_token_padding = mixed_qkv.shape[0] - batch_size = cache_indices.shape[0] if ( not self.graph_mode and forward_batch.num_token_non_padded_cpu != num_token_padding @@ -236,23 +235,24 @@ def forward_extend( b = b[: forward_batch.num_token_non_padded_cpu] seq_len = forward_batch.num_token_non_padded_cpu - mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1) - num_accept_tokens = torch.full( + batch_size = cache_indices.shape[0] + draft_token_num = forward_batch.spec_info.draft_token_num + num_accepted_tokens = torch.full( (batch_size,), draft_token_num, dtype=torch.int32, device=mixed_qkv.device, ) - mixed_qkv = torch.ops.npu.causal_conv1d_update( - mixed_qkv_reshaped, - layer.conv_weights.transpose(0, 1).contiguous(), - conv_states, - cache_indices, - layer.bias, - num_accept_tokens, - None, - layer.activation == "silu", - self.pad_slot_id, + mixed_qkv = causal_conv1d_update_v2( + x=mixed_qkv.view(batch_size, draft_token_num, -1).contiguous(), + conv_state=conv_states.contiguous(), + weight=layer.conv_weights.transpose(0, 1).contiguous(), + bias=layer.bias, + activation=layer.activation, + conv_state_indices=cache_indices, + num_accepted_tokens=num_accepted_tokens, + pad_slot_id=-1, + validate_data=False, ).view(seq_len, -1) else: mixed_qkv = mixed_qkv.transpose(0, 1)