283 changes: 211 additions & 72 deletions python/sglang/srt/layers/attention/aiter_backend.py
@@ -89,6 +89,9 @@ class ForwardMetadata:
reduce_partial_map: Optional[torch.Tensor] = None
num_kv_splits: Optional[int] = None
run_graph: Optional[bool] = True
custom_mask: Optional[torch.Tensor] = None
mask_indptr: Optional[torch.Tensor] = None
max_extend_len: Optional[int] = None


global_workspace_buffer = None
@@ -123,7 +126,6 @@ def __init__(
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.head_dim = model_runner.model_config.head_dim
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
@@ -133,6 +135,21 @@

self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA

# Get v_head_dim based on model type
if self.use_mla:
# For MLA models, get v_head_dim from model config
self.v_head_dim = model_runner.model_config.v_head_dim
elif (
model_runner.hybrid_gdn_config is not None
or model_runner.kimi_linear_config is not None
):
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
-1
]

# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
@@ -152,6 +169,9 @@ def __init__(
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
self.mask_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)

# Create prefill indices updater
if not skip_prefill:
@@ -562,21 +582,28 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
run_graph=False,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
# Non-MLA draft_extend: use triton extend kernel with causal masking
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
self.req_to_token,
)
)
kv_indices = kv_indices.to(torch.int64)
draft_max_extend_len = torch.max(spec_info.accept_length).item()

self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
kv_indptr,
kv_indices,
qo_indptr,
None,
draft_max_extend_len,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
custom_mask=custom_mask,
mask_indptr=None,
max_extend_len=draft_max_extend_len,
)
elif forward_batch.forward_mode.is_target_verify():
if self.use_mla:
@@ -658,21 +685,50 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
run_graph=False,
)
else:
self.indices_updater_prefill.update(
# Non-MLA target_verify: use triton extend kernel with custom mask
bs = len(forward_batch.req_pool_indices)
draft_num = spec_info.draft_token_num

qo_indptr = torch.arange(
0,
(1 + bs) * draft_num,
step=draft_num,
dtype=torch.int32,
device=self.device,
)

kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]

kv_indices = torch.empty(
kv_indptr[-1], dtype=torch.int64, device=self.device
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)

custom_mask = spec_info.custom_mask
seq_mask_len = draft_num * (forward_batch.seq_lens + draft_num)
mask_indptr = self.mask_indptr
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
mask_indptr = mask_indptr[: bs + 1]

self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
kv_indptr,
kv_indices,
qo_indptr,
None,
draft_num,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
custom_mask=custom_mask,
mask_indptr=mask_indptr,
max_extend_len=draft_num,
)
else:
prefix_lens = forward_batch.extend_prefix_lens
@@ -976,22 +1032,48 @@ def init_forward_metadata_capture_cuda_graph(
# num_kv_splits_indptr=num_kv_splits_indptr,
)
else:
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
# Non-MLA target_verify cuda graph: use triton extend kernel metadata
draft_num = self.num_draft_tokens
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * draft_num,
step=draft_num,
dtype=torch.int32,
device=self.device,
)

kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)

kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)

custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = draft_num * (seq_lens + draft_num)
mask_indptr = self.mask_indptr
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
mask_indptr = mask_indptr[: bs + 1]

self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
kv_indptr,
kv_indices,
qo_indptr,
None,
draft_num,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
custom_mask=custom_mask,
mask_indptr=mask_indptr,
max_extend_len=draft_num,
)
elif forward_mode.is_draft_extend():
num_tokens_per_bs = self.speculative_num_steps + 1
@@ -1015,53 +1097,67 @@ def init_forward_metadata_capture_cuda_graph(
kv_indices,
self.req_to_token.stride(0),
)
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = num_tokens_per_bs

if _use_mla_ps_kernel:
if self.use_mla:
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = num_tokens_per_bs

num_kv_splits = self.max_split_per_batch
if _use_mla_ps_kernel:

self.make_mla_meta_data(
qo_indptr,
num_kv_splits = self.max_split_per_batch

self.make_mla_meta_data(
qo_indptr,
kv_indptr,
kv_last_page_len,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
)

work_metadata = self.work_metadata
work_info_set = self.work_info_set
work_indptr = self.work_indptr

reduce_indptr = self.reduce_indptr
reduce_final_map = self.reduce_final_map
reduce_partial_map = self.reduce_partial_map

self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
self.work_metadata,
self.work_info_set,
self.work_indptr,
self.reduce_indptr,
self.reduce_final_map,
self.reduce_partial_map,
max_q_len,
fast_mode=fast_mode,
max_split_per_batch=num_kv_splits,
intra_batch_mode=intra_batch_mode,
kv_indptr[-1].item(),
work_metadata=work_metadata,
work_info_set=work_info_set,
work_indptr=work_indptr,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
num_kv_splits=num_kv_splits,
)
else:
# Non-MLA draft_extend cuda graph: use triton extend kernel
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
None,
num_tokens_per_bs,
None,
custom_mask=None,
mask_indptr=None,
max_extend_len=num_tokens_per_bs,
)

work_metadata = self.work_metadata
work_info_set = self.work_info_set
work_indptr = self.work_indptr

reduce_indptr = self.reduce_indptr
reduce_final_map = self.reduce_final_map
reduce_partial_map = self.reduce_partial_map

self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_q_len,
kv_indptr[-1].item(),
work_metadata=work_metadata,
work_info_set=work_info_set,
work_indptr=work_indptr,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
num_kv_splits=num_kv_splits,
# num_kv_splits_indptr=num_kv_splits_indptr,
)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -1172,7 +1268,10 @@ def init_forward_metadata_replay_cuda_graph(
dtype=torch.int32,
device=self.device,
)
kv_lens = seq_lens + self.num_draft_tokens
if self.use_mla:
kv_lens = seq_lens + self.num_draft_tokens
else:
kv_lens = seq_lens
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
Expand All @@ -1185,6 +1284,15 @@ def init_forward_metadata_replay_cuda_graph(
kv_indices,
self.req_to_token.stride(0),
)
if not self.use_mla:
# Non-MLA: update custom_mask and mask_indptr for triton extend kernel
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (
seq_lens + self.num_draft_tokens
)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)

kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
max_q_len = self.num_draft_tokens
@@ -1642,6 +1750,37 @@ def forward_extend(
f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}"
)
else:
if (
forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
# Use triton extend kernel which supports custom masks and causal masking
if layer.qk_head_dim != layer.v_head_dim:
o = q.new_empty(
(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
)
else:
o = torch.empty_like(q)

self.extend_attention_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.custom_mask,
True, # causal
self.forward_metadata.mask_indptr,
self.forward_metadata.max_extend_len,
layer.scaling,
logit_cap=layer.logit_cap,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
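For context on the metadata introduced above: the non-MLA target_verify and draft_extend paths all build their per-request offset arrays with the same cumulative-sum pattern. Below is a minimal, standalone sketch of that layout in plain PyTorch; the batch size, draft token count, and sequence lengths are hypothetical, and no triton kernel or KV pool is involved.

import torch

# Hypothetical sizes, chosen only to illustrate the layout.
bs = 3                                   # batch size
draft_num = 4                            # speculative draft tokens per request
seq_lens = torch.tensor([7, 5, 9])       # per-request KV lengths

# Each request contributes exactly draft_num query tokens.
qo_indptr = torch.arange(0, (bs + 1) * draft_num, step=draft_num, dtype=torch.int32)

# Each request contributes seq_lens[i] KV entries; kv_indptr is the running sum.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)

# The flattened custom mask stores one draft_num x (seq_len + draft_num) block
# per request; mask_indptr records where each request's block starts.
seq_mask_len = draft_num * (seq_lens + draft_num)
mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)

print(qo_indptr)    # -> [0, 4, 8, 12]
print(kv_indptr)    # -> [0, 7, 12, 21]
print(mask_indptr)  # -> [0, 44, 80, 132]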
4 changes: 2 additions & 2 deletions python/sglang/srt/models/qwen3_next.py
@@ -385,9 +385,9 @@ def _forward_input_proj(self, hidden_states: torch.Tensor):

seq_len, _ = hidden_states.shape
if (
seq_len < DUAL_STREAM_TOKEN_THRESHOLD
and self.alt_stream is not None
self.alt_stream is not None
and get_is_capture_mode()
and seq_len < DUAL_STREAM_TOKEN_THRESHOLD
):
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)