diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
index 8a343d43b4f8..fa82245ed31e 100644
--- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
+++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py
@@ -5,6 +5,10 @@
 
 import torch
 import torch_npu
 
+from sgl_kernel_npu.attention.sinks_attention import (
+    attention_sinks_prefill_triton,
+    attention_sinks_triton,
+)
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.hardware_backend.npu.attention.mla_preprocess import (
@@ -260,9 +264,17 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             // self.page_size
         )
         if forward_batch.extend_seq_lens is not None:
+            self.forward_metadata.extend_seq_lens = forward_batch.extend_seq_lens
             self.forward_metadata.extend_seq_lens_cpu_int = (
                 forward_batch.extend_seq_lens.cpu().int()
             )
+        if forward_batch.seq_lens is not None:
+            self.forward_metadata.seq_lens = forward_batch.seq_lens.int()
+        else:
+            self.forward_metadata.seq_lens = forward_batch.seq_lens_cpu.to(
+                self.device
+            ).int()
+        self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
         if (
             not forward_batch.forward_mode.is_draft_extend_v2()
@@ -576,6 +588,7 @@ def forward_extend(
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
         topk_indices: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if topk_indices is not None:
             return self.forward_sparse(
@@ -617,6 +630,22 @@ def forward_extend(
         k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
         v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
 
+        if sinks is not None:
+            attn_out = attention_sinks_prefill_triton(
+                q,
+                k_cache,
+                v_cache,
+                sinks,
+                self.forward_metadata.extend_seq_lens,
+                self.forward_metadata.block_tables,
+                self.forward_metadata.seq_lens,
+                layer.scaling,
+                layer.sliding_window_size,
+                layer.tp_q_head_num,
+                layer.tp_k_head_num,
+            )
+            return attn_out
+
         if self.use_fia:
             """FIA will support multi-bs in the later version of CANN"""
             q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim)
@@ -1036,6 +1065,7 @@ def forward_decode_graph(
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if save_kv_cache:
             if self.use_mla:
@@ -1049,6 +1079,24 @@ def forward_decode_graph(
                     layer, forward_batch.out_cache_loc, k, v
                 )
 
+        if sinks is not None:
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            attn_out = attention_sinks_triton(
+                q,
+                k_cache,
+                v_cache,
+                sinks,
+                self.forward_metadata.block_tables,
+                self.forward_metadata.seq_lens,
+                layer.scaling,
+                layer.sliding_window_size,
+                layer.tp_q_head_num,
+                layer.tp_k_head_num,
+            )
+            return attn_out
+
         if not self.use_mla:
             num_tokens = q.shape[0]
             """PA will support bs
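
For context, the sketch below illustrates the attention-sinks computation that the new `attention_sinks_triton` decode path is presumably implementing: a learned per-head sink logit joins the softmax normalization but contributes no value, so it absorbs probability mass instead of attending to any token. This is a minimal plain-PyTorch reference for a single decode token under a GQA layout; the function name, signature, and window semantics here are illustrative assumptions, not the actual `sgl_kernel_npu` kernel API or its paged (block-table) memory layout.

```python
import torch


def sinks_attention_reference(
    q: torch.Tensor,      # [num_q_heads, head_dim] query for one decode token
    k: torch.Tensor,      # [seq_len, num_kv_heads, head_dim] cached keys (assumed contiguous, not paged)
    v: torch.Tensor,      # [seq_len, num_kv_heads, head_dim] cached values
    sinks: torch.Tensor,  # [num_q_heads] learned per-head sink logits
    scale: float,
    sliding_window: int = -1,
) -> torch.Tensor:
    num_q_heads, _ = q.shape
    seq_len, num_kv_heads, _ = k.shape
    group = num_q_heads // num_kv_heads  # GQA: query heads per KV head

    # Expand KV heads so each query head indexes its own K/V slice.
    k = k.repeat_interleave(group, dim=1).transpose(0, 1)  # [num_q_heads, seq_len, head_dim]
    v = v.repeat_interleave(group, dim=1).transpose(0, 1)

    # Raw attention logits for the single query token.
    logits = torch.einsum("hd,hsd->hs", q, k) * scale  # [num_q_heads, seq_len]

    # Optional sliding window: only the most recent `sliding_window` keys are visible.
    if sliding_window > 0 and seq_len > sliding_window:
        logits[:, : seq_len - sliding_window] = float("-inf")

    # Append the sink logit: it participates in normalization but has no value row,
    # so it only soaks up attention mass.
    logits = torch.cat([logits, sinks.unsqueeze(-1)], dim=-1)  # [num_q_heads, seq_len + 1]
    probs = torch.softmax(logits, dim=-1)[:, :-1]              # drop the sink column

    return torch.einsum("hs,hsd->hd", probs, v)  # [num_q_heads, head_dim]
```

The prefill variant added in `forward_extend` follows the same idea per query position, which is why it additionally takes `extend_seq_lens` to locate each request's new tokens, while the decode path only needs `block_tables` and `seq_lens` to gather the cached KV.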