diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index 4183ef227086..65e5af5efd3c 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -1430,6 +1430,7 @@ def forward_extend( layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, + sinks=None, ): cache_loc = ( forward_batch.out_cache_loc @@ -1798,6 +1799,10 @@ def forward_extend( k_cache = k_cache.to(dtype) v_cache = v_cache.to(dtype) + window_size = (-1, -1) + if layer.sliding_window_size is not None and layer.sliding_window_size > -1: + window_size = (layer.sliding_window_size, -1) + o = mha_batch_prefill_func( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), k_cache, @@ -1812,6 +1817,8 @@ def forward_extend( alibi_slopes=None, return_lse=False, return_attn_probs=False, + window_size=window_size, + sink_ptr=sinks, ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f5e388fe92a1..4ad77cbae8a2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1368,7 +1368,14 @@ def _handle_model_specific_adjustments(self): else: self.attention_backend = "triton" - supported_backends = ["triton", "trtllm_mha", "fa3", "fa4", "ascend"] + supported_backends = [ + "triton", + "trtllm_mha", + "fa3", + "fa4", + "ascend", + "aiter", + ] prefill_attn_backend, decode_attn_backend = self.get_attention_backends() assert ( prefill_attn_backend in supported_backends