Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/attention/aiter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,7 @@ def forward_extend(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sinks=None,
):
cache_loc = (
forward_batch.out_cache_loc
Expand Down Expand Up @@ -1798,6 +1799,10 @@ def forward_extend(
k_cache = k_cache.to(dtype)
v_cache = v_cache.to(dtype)

window_size = (-1, -1)
if layer.sliding_window_size is not None and layer.sliding_window_size > -1:
window_size = (layer.sliding_window_size, -1)

o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
Expand All @@ -1812,6 +1817,8 @@ def forward_extend(
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
window_size=window_size,
sink_ptr=sinks,
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
Expand Down
9 changes: 8 additions & 1 deletion python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,14 @@ def _handle_model_specific_adjustments(self):
else:
self.attention_backend = "triton"

supported_backends = ["triton", "trtllm_mha", "fa3", "fa4", "ascend"]
supported_backends = [
"triton",
"trtllm_mha",
"fa3",
"fa4",
"ascend",
"aiter",
]
prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
assert (
prefill_attn_backend in supported_backends
Expand Down
Loading