diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py index c6bac044f3a..e8b2021fb6f 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -50,9 +50,10 @@ from ..modules.linear import Linear, TensorParallelMode from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated +from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..speculative import SpecMetadata -from ..utils import AuxStreamType +from ..utils import AuxStreamType, EventType from .modeling_qwen3 import Qwen3Attention from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model @@ -387,6 +388,7 @@ def __init__( self.mapping = model_config.mapping self.allreduce = AllReduce(mapping=model_config.mapping, strategy=model_config.allreduce_strategy) + self.aux_stream = aux_stream self.gate = Qwen3NextGate( hidden_size=self.hidden_dim, @@ -425,6 +427,11 @@ def __init__( dtype=config.torch_dtype, quant_config=None) + self.event_dict = { + key: torch.cuda.Event() + for key in [EventType.Main, EventType.MoeShared] + } + def forward( self, hidden_states: torch.Tensor, @@ -450,22 +457,33 @@ def forward( dim=0, sizes=all_rank_num_tokens) - router_logits = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states, - router_logits, - all_rank_num_tokens=all_rank_num_tokens, - use_dp_padding=use_dp_padding, - do_finalize=do_finalize, - ) + def _compute_routed_output(): + router_logits = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states, + router_logits, + all_rank_num_tokens=all_rank_num_tokens, + use_dp_padding=use_dp_padding, + do_finalize=do_finalize, + ) + return final_hidden_states + def _compute_shared_output(): + 
shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_expert_output + return shared_expert_output + + final_hidden_states, shared_expert_output = maybe_execute_in_parallel( + _compute_routed_output, + _compute_shared_output, + self.event_dict[EventType.Main], + self.event_dict[EventType.MoeShared], + self.aux_stream, + ) if not do_finalize: return final_hidden_states - shared_expert_output = self.shared_expert(hidden_states) - shared_expert_output = F.sigmoid( - self.shared_expert_gate(hidden_states)) * shared_expert_output - final_hidden_states = final_hidden_states + shared_expert_output if not self.enable_attention_dp and self.mapping.tp_size > 1: @@ -543,22 +561,21 @@ def fused_qkvzba_split_reshape_cat( ): batch, seq_len = mixed_qkvz.shape[0], 1 qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v - mixed_qkv = torch.empty( - [batch * seq_len, qkv_dim_t], - dtype=mixed_qkvz.dtype, - device=mixed_qkvz.device, - ) - z = torch.empty( - [batch * seq_len, num_heads_v, head_v], - dtype=mixed_qkvz.dtype, - device=mixed_qkvz.device, - ) - b = torch.empty( - [batch * seq_len, num_heads_v], - dtype=mixed_ba.dtype, - device=mixed_ba.device, - ) - a = torch.empty_like(b) + batch_seq = batch * seq_len + + # Directly allocate output tensors in their final shapes (no intermediate buffers) + mixed_qkv = torch.empty((batch_seq, qkv_dim_t), + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device) + z = torch.empty((batch_seq, num_heads_v, head_v), + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device) + b = torch.empty((batch_seq, num_heads_v), + dtype=mixed_ba.dtype, + device=mixed_ba.device) + a = torch.empty((batch_seq, num_heads_v), + dtype=mixed_ba.dtype, + device=mixed_ba.device) grid = (batch * seq_len, num_heads_qk) fused_qkvzba_split_reshape_cat_kernel[grid]( mixed_qkv, @@ -765,43 +782,42 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): """ Derives `query`, 
`key` and `value` tensors from `mixed_qkvzba`. """ - new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( - self.num_k_heads // self.attn_tp_size, - (self.head_k_dim + self.head_k_dim + - (self.head_v_dim + self.head_v_dim) * self.num_v_heads // - self.num_k_heads), - ) - new_tensor_shape_ba = mixed_ba.size()[:-1] + ( - self.num_k_heads // self.attn_tp_size, - 2 * self.num_v_heads // self.num_k_heads, - ) - - mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) - mixed_ba = mixed_ba.view(*new_tensor_shape_ba) - - split_arg_list_qkvz = [ - self.head_k_dim, - self.head_k_dim, - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - (self.num_v_heads // self.num_k_heads * self.head_v_dim), - ] - split_arg_list_ba = [ - self.num_v_heads // self.num_k_heads, - self.num_v_heads // self.num_k_heads, - ] - - # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] - # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] - (query, key, value, z) = torch.split(mixed_qkvz, - split_arg_list_qkvz, - dim=2) - (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) - - # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn] - value = value.reshape(value.size(0), -1, self.head_v_dim) - z = z.reshape(z.size(0), -1, self.head_v_dim) - b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size) - a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size) + batch_size = mixed_qkvz.size(0) + num_k_heads_local = self.num_k_heads // self.attn_tp_size + num_v_heads_local = self.num_v_heads // self.attn_tp_size + heads_ratio = self.num_v_heads // self.num_k_heads + + # Reshape qkvz: [b, d] -> [b, ng, (2*hk + 2*np/ng*hv)] + qkvz_dim_per_head = (self.head_k_dim * 2 + + self.head_v_dim * heads_ratio * 2) + mixed_qkvz = mixed_qkvz.view(batch_size, num_k_heads_local, + qkvz_dim_per_head) + + # Reshape ba: [b, d] -> [b, ng, 2*np/ng] + mixed_ba = mixed_ba.view(batch_size, num_k_heads_local, heads_ratio * 2) + + # Direct 
slicing instead of torch.split for better performance + # Compute split boundaries once + q_end = self.head_k_dim + k_end = q_end + self.head_k_dim + v_end = k_end + heads_ratio * self.head_v_dim + z_end = v_end + heads_ratio * self.head_v_dim + + # Slice qkvz components: [b, ng, dim] -> individual components + query = mixed_qkvz[..., :q_end] + key = mixed_qkvz[..., q_end:k_end] + + # Slices of the packed last dim are strided, not contiguous: reshape aliases when it can and copies otherwise (view would raise) + # Layout: [v_concat | z_concat], need to reshape each separately + value = mixed_qkvz[..., k_end:v_end].reshape(batch_size, num_v_heads_local, + self.head_v_dim) + z = mixed_qkvz[..., v_end:z_end].reshape(batch_size, num_v_heads_local, + self.head_v_dim) + + # Slice ba components: [b, ng, 2*np/ng] -> [b, np] each + # The b/a slices are strided as well, so reshape (not view) is required to merge the head dims + b = mixed_ba[..., :heads_ratio].reshape(batch_size, num_v_heads_local) + a = mixed_ba[..., heads_ratio:].reshape(batch_size, num_v_heads_local) return query, key, value, z, b, a @@ -817,7 +833,6 @@ def forward_decode( a = kwargs["a"] b = kwargs["b"] cache_indices = kwargs["cache_indices"] - query_start_loc = torch.arange(0, num_decodes + 1, device=cu_seqlens.device).to(torch.long) @@ -831,15 +846,11 @@ def forward_decode( conv_state_indices=cache_indices, ) - query, key, value = torch.split( - mixed_qkv, - [ - self.key_dim // self.attn_tp_size, - self.key_dim // self.attn_tp_size, - self.value_dim // self.attn_tp_size, - ], - dim=-1, - ) + # Direct slicing instead of torch.split for better performance + key_size = self.key_dim // self.attn_tp_size + query = mixed_qkv[..., :key_size] + key = mixed_qkv[..., key_size:key_size * 2] + value = mixed_qkv[..., key_size * 2:] # Reshape from [l, h*d] to [1, l, h, d] seq_len = query.shape[0] num_heads = query.shape[1] // self.head_k_dim @@ -925,8 +936,7 @@ def forward_extend( conv_states=conv_states_to_use, has_initial_state=has_initial_states, cache_indices=cache_indices, -
query_start_loc=query_start_loc, - ).transpose(0, 1) + query_start_loc=query_start_loc).transpose(0, 1) key_split_dim = self.key_dim // self.attn_tp_size value_split_dim = self.value_dim // self.attn_tp_size @@ -1024,9 +1034,8 @@ def forward( projected_states_qkvz = self.in_proj_qkvz(hidden_states) projected_states_ba = self.in_proj_ba(hidden_states) - query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) + # Use fused kernel when possible to avoid elementwise ops if self.num_v_heads // self.num_k_heads in [1, 2, 4]: # and is_cuda_graph: mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( @@ -1060,17 +1069,11 @@ def forward( "num_prefill": num_prefills, "num_decode": num_decodes, } - - new_implementation = True - if new_implementation: - if num_prefills > 0: - attn_out = self.forward_extend(conv_states, ssm_states, - **kwargs) - else: - attn_out = self.forward_decode(conv_states, ssm_states, - num_decodes, - mamba_metadata.cu_seqlens, - **kwargs) + if num_prefills > 0: + attn_out = self.forward_extend(conv_states, ssm_states, **kwargs) + else: + attn_out = self.forward_decode(conv_states, ssm_states, num_decodes, + mamba_metadata.cu_seqlens, **kwargs) z_shape_og = z.shape # reshape input data into 2D tensor @@ -1125,7 +1128,7 @@ def __init__( "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0" self.enable_fusion &= not self.enable_attention_dp - self.mapping.has_tp() + # has_tp = self.mapping.has_tp() has_pp = self.mapping.has_pp() # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp @@ -1284,7 +1287,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig], "TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0" self.enable_fusion &= not self.enable_attention_dp - self.mapping.has_tp() + # has_tp = self.mapping.has_tp() has_pp = self.mapping.has_pp() # self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp diff --git a/tensorrt_llm/_torch/modules/fla/chunk.py 
b/tensorrt_llm/_torch/modules/fla/chunk.py index 6f908f74077..9fc5e322374 100644 --- a/tensorrt_llm/_torch/modules/fla/chunk.py +++ b/tensorrt_llm/_torch/modules/fla/chunk.py @@ -90,8 +90,6 @@ def forward( cu_seqlens: Optional[torch.LongTensor] = None, use_qk_l2norm_in_kernel: bool = False, ): - pass - if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k)