diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index bc941fdf7ec6..0062a55737fa 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -89,6 +89,9 @@ class ForwardMetadata: reduce_partial_map: Optional[torch.Tensor] = None num_kv_splits: Optional[int] = None run_graph: Optional[bool] = True + custom_mask: Optional[torch.Tensor] = None + mask_indptr: Optional[torch.Tensor] = None + max_extend_len: Optional[int] = None global_workspace_buffer = None @@ -123,7 +126,6 @@ def __init__( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim - self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] self.num_kv_head = model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) @@ -133,6 +135,21 @@ def __init__( self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA + # Get v_head_dim based on model type + if self.use_mla: + # For MLA models, get v_head_dim from model config + self.v_head_dim = model_runner.model_config.v_head_dim + elif ( + model_runner.hybrid_gdn_config is not None + or model_runner.kimi_linear_config is not None + ): + # For hybrid linear models, layer_id = 0 may not be full attention + self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() + else: + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[ + -1 + ] + # Parse constants self.max_context_len = model_runner.model_config.context_len self.skip_prefill = skip_prefill @@ -152,6 +169,9 @@ def __init__( self.qo_indptr = torch.zeros( (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) + self.mask_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int64, device=model_runner.device + ) # Create prefill indices updater if not skip_prefill: @@ -562,21 +582,28 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): run_graph=False, ) else: - self.indices_updater_prefill.update( - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, + # Non-MLA draft_extend: use triton extend kernel with causal masking + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + self.req_to_token, + ) ) + kv_indices = kv_indices.to(torch.int64) + draft_max_extend_len = torch.max(spec_info.accept_length).item() + self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, + kv_indptr, + kv_indices, + qo_indptr, None, + draft_max_extend_len, None, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, + custom_mask=custom_mask, + mask_indptr=None, + max_extend_len=draft_max_extend_len, ) elif forward_batch.forward_mode.is_target_verify(): if self.use_mla: @@ -658,21 +685,50 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): run_graph=False, ) else: - self.indices_updater_prefill.update( + # Non-MLA target_verify: use triton extend kernel with custom mask + bs = len(forward_batch.req_pool_indices) + draft_num = spec_info.draft_token_num + + qo_indptr = torch.arange( + 0, + (1 + bs) * draft_num, + step=draft_num, + dtype=torch.int32, + device=self.device, + ) + + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + + kv_indices = torch.empty( + kv_indptr[-1], dtype=torch.int64, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, forward_batch.req_pool_indices, forward_batch.seq_lens, - forward_batch.seq_lens_sum, - prefix_lens=None, - encoder_lens=forward_batch.encoder_lens, - spec_info=forward_batch.spec_info, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), ) + + custom_mask = spec_info.custom_mask + seq_mask_len = draft_num * (forward_batch.seq_lens + draft_num) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, + kv_indptr, + kv_indices, + qo_indptr, None, + draft_num, None, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, + custom_mask=custom_mask, + mask_indptr=mask_indptr, + max_extend_len=draft_num, ) else: prefix_lens = forward_batch.extend_prefix_lens @@ -976,22 +1032,48 @@ def init_forward_metadata_capture_cuda_graph( # num_kv_splits_indptr=num_kv_splits_indptr, ) else: - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_prefill.update( + # Non-MLA target_verify cuda graph: use triton extend kernel metadata + draft_num = self.num_draft_tokens + qo_indptr = self.qo_indptr[: bs + 1] + qo_indptr[: bs + 1] = torch.arange( + 0, + (1 + bs) * draft_num, + step=draft_num, + dtype=torch.int32, + device=self.device, + ) + + kv_indptr = self.kv_indptr[: bs + 1] + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, req_pool_indices, seq_lens, - seq_lens_sum, - prefix_lens=None, - encoder_lens=encoder_lens, - spec_info=spec_info, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), ) + + custom_mask = self.cuda_graph_custom_mask + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = draft_num * (seq_lens + draft_num) + mask_indptr = self.mask_indptr + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0) + mask_indptr = mask_indptr[: bs + 1] + self.forward_metadata = ForwardMetadata( - self.indices_updater_prefill.kv_indptr, - self.indices_updater_prefill.kv_indices, + kv_indptr, + kv_indices, + qo_indptr, None, + draft_num, None, - self.indices_updater_prefill.max_q_len, - self.indices_updater_prefill.max_kv_len, + custom_mask=custom_mask, + mask_indptr=mask_indptr, + max_extend_len=draft_num, ) elif forward_mode.is_draft_extend(): num_tokens_per_bs = self.speculative_num_steps + 1 @@ -1015,53 +1097,67 @@ def init_forward_metadata_capture_cuda_graph( kv_indices, self.req_to_token.stride(0), ) - kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] - max_q_len = num_tokens_per_bs - if _use_mla_ps_kernel: + if self.use_mla: + kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] + max_q_len = num_tokens_per_bs - num_kv_splits = self.max_split_per_batch + if _use_mla_ps_kernel: - self.make_mla_meta_data( - qo_indptr, + num_kv_splits = self.max_split_per_batch + + self.make_mla_meta_data( + qo_indptr, + kv_indptr, + kv_last_page_len, + self.work_metadata, + self.work_info_set, + self.work_indptr, + self.reduce_indptr, + self.reduce_final_map, + self.reduce_partial_map, + max_q_len, + fast_mode=fast_mode, + max_split_per_batch=num_kv_splits, + intra_batch_mode=intra_batch_mode, + ) + + work_metadata = self.work_metadata + work_info_set = self.work_info_set + work_indptr = self.work_indptr + + reduce_indptr = self.reduce_indptr + reduce_final_map = self.reduce_final_map + reduce_partial_map = self.reduce_partial_map + + self.forward_metadata = ForwardMetadata( kv_indptr, + kv_indices, + qo_indptr, kv_last_page_len, - self.work_metadata, - self.work_info_set, - self.work_indptr, - self.reduce_indptr, - self.reduce_final_map, - self.reduce_partial_map, max_q_len, - fast_mode=fast_mode, - max_split_per_batch=num_kv_splits, - intra_batch_mode=intra_batch_mode, + kv_indptr[-1].item(), + work_metadata=work_metadata, + work_info_set=work_info_set, + work_indptr=work_indptr, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, + num_kv_splits=num_kv_splits, + ) + else: + # Non-MLA draft_extend cuda graph: use triton extend kernel + self.forward_metadata = ForwardMetadata( + kv_indptr, + kv_indices, + qo_indptr, + None, + num_tokens_per_bs, + None, + custom_mask=None, + mask_indptr=None, + max_extend_len=num_tokens_per_bs, ) - - work_metadata = self.work_metadata - work_info_set = self.work_info_set - work_indptr = self.work_indptr - - reduce_indptr = self.reduce_indptr - reduce_final_map = self.reduce_final_map - reduce_partial_map = self.reduce_partial_map - - self.forward_metadata = ForwardMetadata( - kv_indptr, - kv_indices, - qo_indptr, - kv_last_page_len, - max_q_len, - kv_indptr[-1].item(), - work_metadata=work_metadata, - work_info_set=work_info_set, - work_indptr=work_indptr, - reduce_indptr=reduce_indptr, - reduce_final_map=reduce_final_map, - reduce_partial_map=reduce_partial_map, - num_kv_splits=num_kv_splits, - # num_kv_splits_indptr=num_kv_splits_indptr, - ) else: raise ValueError(f"Invalid mode: {forward_mode=}") @@ -1172,7 +1268,10 @@ def init_forward_metadata_replay_cuda_graph( dtype=torch.int32, device=self.device, ) - kv_lens = seq_lens + self.num_draft_tokens + if self.use_mla: + kv_lens = seq_lens + self.num_draft_tokens + else: + kv_lens = seq_lens kv_indptr = self.kv_indptr[: bs + 1] kv_indptr[1 : bs + 1] = torch.cumsum(kv_lens, dim=0) kv_indices = self.cuda_graph_kv_indices @@ -1185,6 +1284,15 @@ def init_forward_metadata_replay_cuda_graph( kv_indices, self.req_to_token.stride(0), ) + if not self.use_mla: + # Non-MLA: update custom_mask and mask_indptr for triton extend kernel + custom_mask = self.cuda_graph_custom_mask + custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask + seq_mask_len = self.num_draft_tokens * ( + seq_lens + self.num_draft_tokens + ) + mask_indptr = self.mask_indptr[: bs + 1] + mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0) kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs] max_q_len = self.num_draft_tokens @@ -1642,6 +1750,37 @@ def forward_extend( f"Invalid forward mode for MLA prefill: {forward_batch.forward_mode=}" ) else: + if ( + forward_batch.forward_mode.is_target_verify() + or forward_batch.forward_mode.is_draft_extend() + ): + # Use triton extend kernel which supports custom masks and causal masking + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty( + (q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + ) + else: + o = torch.empty_like(q) + + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), + forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), + self.forward_metadata.qo_indptr, + self.forward_metadata.kv_indptr, + self.forward_metadata.kv_indices, + self.forward_metadata.custom_mask, + True, # causal + self.forward_metadata.mask_indptr, + self.forward_metadata.max_extend_len, + layer.scaling, + logit_cap=layer.logit_cap, + ) + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id ) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 39a086c986b4..11be61b860f2 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -385,9 +385,9 @@ def _forward_input_proj(self, hidden_states: torch.Tensor): seq_len, _ = hidden_states.shape if ( - seq_len < DUAL_STREAM_TOKEN_THRESHOLD - and self.alt_stream is not None + self.alt_stream is not None and get_is_capture_mode() + and seq_len < DUAL_STREAM_TOKEN_THRESHOLD ): current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream)