Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f43110a
initial support
Ying1123 Aug 5, 2025
2a91ff1
Add attention sink support init
yyihuang Aug 5, 2025
5be26e2
Merge branch 'oss-oai' of https://github.com/sgl-project/sglang into …
yyihuang Aug 5, 2025
a9cd1b4
fix compatibility
Ying1123 Aug 5, 2025
5ae082e
minor
Ying1123 Aug 5, 2025
52c9a34
Merge branch 'main' into oss-oai
zhyncs Aug 5, 2025
bee0922
fix ci
Ying1123 Aug 5, 2025
5bb34cf
Merge branch 'oss-oai' of https://github.com/sgl-project/sglang into …
yyihuang Aug 5, 2025
e2efeca
fix ci
Ying1123 Aug 5, 2025
473a74b
fix sinks
yyihuang Aug 5, 2025
7faa241
Merge branch 'oss-oai' into attn-sink
zhyncs Aug 5, 2025
a9ccc30
Merge branch 'main' of https://github.com/sgl-project/sglang into att…
yyihuang Aug 5, 2025
6f1ee5c
upd
yyihuang Aug 5, 2025
aeb544d
ckpt: trtllm decode fail
yyihuang Aug 5, 2025
355a6be
fix window size and attn backend choice
yyihuang Aug 6, 2025
3955ae0
minor update: cuda graph
yyihuang Aug 6, 2025
f8e0704
Merge branch 'main' of https://github.com/sgl-project/sglang into att…
yyihuang Aug 6, 2025
85421cc
upd
yyihuang Aug 6, 2025
3b0b1f5
Merge branch 'main' of https://github.com/sgl-project/sglang into att…
yyihuang Aug 6, 2025
88a867c
upd
yyihuang Aug 6, 2025
38f5d79
fix cuda graph error
yyihuang Aug 6, 2025
e46b808
Merge branch 'main' of https://github.com/sgl-project/sglang into att…
yyihuang Aug 6, 2025
175125c
Merge branch 'main' of https://github.com/sgl-project/sglang into att…
yyihuang Aug 7, 2025
9ee038a
upd
yyihuang Aug 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 19 additions & 15 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

"""
Support attention backend for TRTLLM MLA kernels from flashinfer.
Support attention backend for TRTLLM MHA kernels from flashinfer.
The kernel supports sm100 only, with sliding window and attention sink features.
"""

from dataclasses import dataclass
Expand Down Expand Up @@ -57,11 +58,6 @@ def __init__(

# MHA-specific dimensions
self.max_context_len = model_runner.model_config.context_len
self.sliding_window_size = (
model_runner.sliding_window_size
if model_runner.sliding_window_size is not None
else -1 # -1 indicates full attention
)
self.hidden_size = config.hidden_size

# Runtime parameters
Expand Down Expand Up @@ -117,10 +113,10 @@ def init_forward_metadata_capture_cuda_graph(
metadata = TRTLLMMHAMetadata()

# Get sequence information
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)

# Precompute maximum sequence length
metadata.max_seq_len_k = seq_lens.max().item()
metadata.max_seq_len_k = self.max_context_len

# Precompute page table
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
Expand Down Expand Up @@ -149,7 +145,7 @@ def init_forward_metadata_replay_cuda_graph(
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
metadata.max_seq_len_k = max_len
metadata.max_seq_len_k = self.max_context_len

metadata.cache_seqlens_int32.copy_(seq_lens)
page_indices = self.req_to_token[
Expand Down Expand Up @@ -217,6 +213,7 @@ def forward_decode(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
) -> torch.Tensor:
"""Run forward for decode using TRTLLM MHA kernel."""
cache_loc = forward_batch.out_cache_loc
Expand All @@ -228,7 +225,7 @@ def forward_decode(
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
# shape conversion:
# [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
).permute(0, 2, 1, 3)
Expand All @@ -237,7 +234,7 @@ def forward_decode(
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)

# TODO: bmm1_scale and bmm2_scale might require modification
# TODO: add support for quantization
q_scale = 1.0
k_scale = (
layer.k_scale_float
Expand All @@ -246,6 +243,8 @@ def forward_decode(
)
bmm1_scale = q_scale * k_scale * layer.scaling
bmm2_scale = 1.0
# sink: additional value per head in the denominator of the softmax.
attention_sink = kwargs.get("sinks", None)

# Call TRT-LLM kernel
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
Expand All @@ -258,8 +257,9 @@ def forward_decode(
max_seq_len=self.forward_metadata.max_seq_len_k,
bmm1_scale=bmm1_scale,
bmm2_scale=bmm2_scale,
window_left=self.sliding_window_size,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
Expand All @@ -272,13 +272,15 @@ def forward_extend(
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
**kwargs,
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function forward_extend returns a torch.Tensor, but it's missing a return type hint. Adding it improves code readability and allows static analysis tools to catch potential bugs.

):
        cache_loc = forward_batch.out_cache_loc

cache_loc = forward_batch.out_cache_loc
if save_kv_cache and k is not None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
# [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
k_cache = k_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
Expand All @@ -288,8 +290,9 @@ def forward_extend(
).permute(0, 2, 1, 3)
kv_cache = (k_cache, v_cache)

# TODO: bmm1_scale and bmm2_scale might require modification
# TODO: Change once quantization is supported
# sink: additional value per head in the denominator of the softmax.
attention_sink = kwargs.get("sinks", None)
# TODO: add support for quantization
q_scale = 1.0
k_scale = (
layer.k_scale_float
Expand All @@ -312,8 +315,9 @@ def forward_extend(
batch_size=forward_batch.batch_size,
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
window_left=self.sliding_window_size,
window_left=layer.sliding_window_size,
# TODO: add attention_sink operation or nvfp4 scale factor if needed
sinks=attention_sink,
)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
6 changes: 3 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,13 +1443,13 @@ def _get_attention_backend_from_str(self, backend_str: str):
)

return CutlassMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mla":
elif backend_str == "trtllm_mla":
if not self.use_mla_backend:
raise ValueError("trtllm_mla backend can only be used with MLA models.")
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

return TRTLLMMLABackend(self)
elif self.server_args.attention_backend == "trtllm_mha":
elif backend_str == "trtllm_mha":
if self.use_mla_backend:
raise ValueError(
"trtllm_mha backend can only be used with non-MLA models."
Expand All @@ -1460,7 +1460,7 @@ def _get_attention_backend_from_str(self, backend_str: str):

return TRTLLMHAAttnBackend(self)

elif self.server_args.attention_backend == "intel_amx":
elif backend_str == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
)
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sinks=self.sinks)
attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
output, _ = self.o_proj(attn_output)
return output

Expand Down
17 changes: 14 additions & 3 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,11 @@ def print_deprecated_warning(message: str):
"trtllm_mla backend does not support speculative decoding yet."
)

if self.attention_backend == "trtllm_mha":
if (
self.attention_backend == "trtllm_mha"
or self.decode_attention_backend == "trtllm_mha"
or self.prefill_attention_backend == "trtllm_mha"
):
if not is_sm100_supported():
raise ValueError(
"TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
Expand All @@ -459,11 +463,18 @@ def print_deprecated_warning(message: str):

if self.speculative_algorithm is not None:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
"trtllm_mha backend does not support speculative decoding yet."
)

model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
self.attention_backend = "triton"
if self.attention_backend is None:
# default is triton, but we could have trtllm_mha as an option
self.attention_backend = "triton"
assert (
self.attention_backend == "trtllm_mha"
or self.attention_backend == "triton"
)

# Check if FlashInfer MXFP4 MoE is enabled
from sglang.srt.utils import get_bool_env_var
Expand Down
Loading