From 80a3287e5d72ad46b1155b004c48ee1ab962bcb0 Mon Sep 17 00:00:00 2001 From: iridiumine Date: Thu, 7 May 2026 08:06:46 +0000 Subject: [PATCH] use causal_conv1d_update_v2 instead of torch.ops.npu.causal_conv1d_update --- .../npu/attention/ascend_gdn_backend.py | 26 +++++++++---------- .../srt/layers/attention/triton_backend.py | 5 +++- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index 6abe5c2fd042..5aaa10d2bb7b 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -8,6 +8,7 @@ from sgl_kernel_npu.mamba.causal_conv1d import ( causal_conv1d_fn_npu, causal_conv1d_update_npu, + causal_conv1d_update_v2, ) from sglang.srt.hardware_backend.npu.attention.ascend_hybrid_linear_attn_backend import ( @@ -224,9 +225,7 @@ def forward_extend( else: has_initial_states = forward_batch.extend_prefix_lens > 0 if is_target_verify: - draft_token_num = forward_batch.spec_info.draft_token_num num_token_padding = mixed_qkv.shape[0] - batch_size = cache_indices.shape[0] if ( not self.graph_mode and forward_batch.num_token_non_padded_cpu != num_token_padding @@ -236,23 +235,24 @@ def forward_extend( b = b[: forward_batch.num_token_non_padded_cpu] seq_len = forward_batch.num_token_non_padded_cpu - mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1) + batch_size = cache_indices.shape[0] + draft_token_num = forward_batch.spec_info.draft_token_num num_accepted_tokens = torch.full( (batch_size,), draft_token_num, dtype=torch.int32, device=mixed_qkv.device, ) - mixed_qkv = torch.ops.npu.causal_conv1d_update( - mixed_qkv_reshaped, - layer.conv_weights.transpose(0, 1).contiguous(), - conv_states, - cache_indices, - layer.bias, - num_accepted_tokens, - None, - layer.activation == "silu", - self.pad_slot_id, + mixed_qkv = causal_conv1d_update_v2( + x=mixed_qkv.view(batch_size, draft_token_num, -1).contiguous(), + conv_state=conv_states.contiguous(), + weight=layer.conv_weights.transpose(0, 1).contiguous(), + bias=layer.bias, + activation=layer.activation, + conv_state_indices=cache_indices, + num_accepted_tokens=num_accepted_tokens, + pad_slot_id=-1, + validate_data=False, ).view(seq_len, -1) else: mixed_qkv = mixed_qkv.transpose(0, 1) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 9747787c111d..4c7d267d9235 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -6,7 +6,6 @@ import torch import triton import triton.language as tl -from sgl_kernel.utils import is_arch_support_pdl from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -20,9 +19,13 @@ get_bool_env_var, get_device_core_count, get_int_env_var, + is_npu, next_power_of_2, ) +if not is_npu(): + from sgl_kernel.utils import is_arch_support_pdl + if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner