189 changes: 96 additions & 93 deletions tensorrt_llm/_torch/models/modeling_qwen3_next.py
@@ -50,9 +50,10 @@
from ..modules.linear import Linear, TensorParallelMode
from ..modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from ..modules.mamba.layernorm_gated import RMSNorm as RMSNormGated
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..speculative import SpecMetadata
from ..utils import AuxStreamType
from ..utils import AuxStreamType, EventType
from .modeling_qwen3 import Qwen3Attention
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import DecoderModel, EagerFusionConfig, register_auto_model
@@ -387,6 +388,7 @@ def __init__(
self.mapping = model_config.mapping
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.aux_stream = aux_stream

self.gate = Qwen3NextGate(
hidden_size=self.hidden_dim,
@@ -425,6 +427,11 @@ def __init__(
dtype=config.torch_dtype,
quant_config=None)

self.event_dict = {
key: torch.cuda.Event()
for key in [EventType.Main, EventType.MoeShared]
}

def forward(
self,
hidden_states: torch.Tensor,
@@ -450,22 +457,33 @@ def forward(
dim=0,
sizes=all_rank_num_tokens)

router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
def _compute_routed_output():
router_logits = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
do_finalize=do_finalize,
)
return final_hidden_states

def _compute_shared_output():
shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output
return shared_expert_output

final_hidden_states, shared_expert_output = maybe_execute_in_parallel(
_compute_routed_output,
_compute_shared_output,
self.event_dict[EventType.Main],
self.event_dict[EventType.MoeShared],
self.aux_stream,
)
if not do_finalize:
return final_hidden_states

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_expert_output

final_hidden_states = final_hidden_states + shared_expert_output

if not self.enable_attention_dp and self.mapping.tp_size > 1:
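Note (not part of the diff): a minimal sketch of the overlap pattern this hunk relies on, using only public `torch.cuda` primitives. It assumes `maybe_execute_in_parallel` behaves roughly like the helper below: record an event on the main stream, run the second callable on the auxiliary stream, then make the main stream wait on its event. The real utility lives in `..modules.multi_stream_utils` and may differ in detail.

```python
import torch


def run_overlapped(fn_main, fn_aux, main_event, aux_event, aux_stream):
    """Illustrative stand-in for maybe_execute_in_parallel (not the real API)."""
    # Mark the point the aux stream must wait for so it sees up-to-date inputs.
    main_event.record()
    out_main = fn_main()                      # routed experts, on the main stream
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(main_event)
        out_aux = fn_aux()                    # shared expert, overlapped on the aux stream
        aux_event.record()
    # The main stream must not consume out_aux before the aux stream is done.
    torch.cuda.current_stream().wait_event(aux_event)
    return out_main, out_aux


if torch.cuda.is_available():
    aux_stream = torch.cuda.Stream()
    events = {k: torch.cuda.Event() for k in ("main", "moe_shared")}
    x = torch.randn(4, 64, device="cuda")
    routed, shared = run_overlapped(
        lambda: x @ x.T,               # stand-in for gate + routed experts
        lambda: torch.sigmoid(x) * x,  # stand-in for shared_expert_gate * shared_expert
        events["main"], events["moe_shared"], aux_stream)
```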
@@ -543,22 +561,21 @@ def fused_qkvzba_split_reshape_cat(
):
batch, seq_len = mixed_qkvz.shape[0], 1
qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
mixed_qkv = torch.empty(
[batch * seq_len, qkv_dim_t],
dtype=mixed_qkvz.dtype,
device=mixed_qkvz.device,
)
z = torch.empty(
[batch * seq_len, num_heads_v, head_v],
dtype=mixed_qkvz.dtype,
device=mixed_qkvz.device,
)
b = torch.empty(
[batch * seq_len, num_heads_v],
dtype=mixed_ba.dtype,
device=mixed_ba.device,
)
a = torch.empty_like(b)
batch_seq = batch * seq_len

# Directly allocate output tensors in their final shapes (no intermediate buffers)
mixed_qkv = torch.empty((batch_seq, qkv_dim_t),
dtype=mixed_qkvz.dtype,
device=mixed_qkvz.device)
z = torch.empty((batch_seq, num_heads_v, head_v),
dtype=mixed_qkvz.dtype,
device=mixed_qkvz.device)
b = torch.empty((batch_seq, num_heads_v),
dtype=mixed_ba.dtype,
device=mixed_ba.device)
a = torch.empty((batch_seq, num_heads_v),
dtype=mixed_ba.dtype,
device=mixed_ba.device)
grid = (batch * seq_len, num_heads_qk)
fused_qkvzba_split_reshape_cat_kernel[grid](
mixed_qkv,
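Note (illustrative, toy numbers only): because the kernel only splits and re-concatenates `mixed_qkvz`, the per-token element count of the input must equal that of `mixed_qkv` plus `z`. A quick sanity check of the shapes allocated above, with made-up head counts standing in for the model config:

```python
# Toy numbers only -- stand-ins for the kernel's shape arguments.
num_heads_qk, head_qk = 4, 128   # grouped Q/K heads and their head dim
num_heads_v, head_v = 8, 64      # V heads and their head dim (ratio = 2 here)

# Width of one token of mixed_qkvz: each K head carries q, k, plus
# (num_heads_v // num_heads_qk) v and z slices.
qkvz_in = num_heads_qk * (2 * head_qk
                          + 2 * (num_heads_v // num_heads_qk) * head_v)

# Widths of the outputs allocated above: flat Q|K|V features plus per-head z.
qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
z_out = num_heads_v * head_v

# The kernel only re-arranges data, so per-token element counts must agree.
assert qkvz_in == qkv_dim_t + z_out
```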
@@ -765,43 +782,42 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
"""
Derives `query`, `key` and `value` tensors from `mixed_qkvzba`.
"""
new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
self.num_k_heads // self.attn_tp_size,
(self.head_k_dim + self.head_k_dim +
(self.head_v_dim + self.head_v_dim) * self.num_v_heads //
self.num_k_heads),
)
new_tensor_shape_ba = mixed_ba.size()[:-1] + (
self.num_k_heads // self.attn_tp_size,
2 * self.num_v_heads // self.num_k_heads,
)

mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
mixed_ba = mixed_ba.view(*new_tensor_shape_ba)

split_arg_list_qkvz = [
self.head_k_dim,
self.head_k_dim,
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
(self.num_v_heads // self.num_k_heads * self.head_v_dim),
]
split_arg_list_ba = [
self.num_v_heads // self.num_k_heads,
self.num_v_heads // self.num_k_heads,
]

# [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)]
# --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng]
(query, key, value, z) = torch.split(mixed_qkvz,
split_arg_list_qkvz,
dim=2)
(b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2)

# [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
value = value.reshape(value.size(0), -1, self.head_v_dim)
z = z.reshape(z.size(0), -1, self.head_v_dim)
b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size)
a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size)
batch_size = mixed_qkvz.size(0)
num_k_heads_local = self.num_k_heads // self.attn_tp_size
num_v_heads_local = self.num_v_heads // self.attn_tp_size
heads_ratio = self.num_v_heads // self.num_k_heads

# Reshape qkvz: [b, d] -> [b, ng, (2*hk + 2*np/ng*hv)]
qkvz_dim_per_head = (self.head_k_dim * 2 +
self.head_v_dim * heads_ratio * 2)
mixed_qkvz = mixed_qkvz.view(batch_size, num_k_heads_local,
qkvz_dim_per_head)

# Reshape ba: [b, d] -> [b, ng, 2*np/ng]
mixed_ba = mixed_ba.view(batch_size, num_k_heads_local, heads_ratio * 2)

# Direct slicing instead of torch.split for better performance
# Compute split boundaries once
q_end = self.head_k_dim
k_end = q_end + self.head_k_dim
v_end = k_end + heads_ratio * self.head_v_dim
z_end = v_end + heads_ratio * self.head_v_dim

# Slice qkvz components: [b, ng, dim] -> individual components
query = mixed_qkvz[..., :q_end]
key = mixed_qkvz[..., q_end:k_end]

# Optimize: Use view (zero-copy) instead of reshape for contiguous slices
# Layout: [v_concat | z_concat], need to reshape each separately
value = mixed_qkvz[..., k_end:v_end].view(batch_size, num_v_heads_local,
self.head_v_dim)
z = mixed_qkvz[..., v_end:z_end].view(batch_size, num_v_heads_local,
self.head_v_dim)

# Slice ba components: [b, ng, 2*np/ng] -> [b, np] each
# Optimize: Use view instead of reshape (zero-copy for contiguous data)
b = mixed_ba[..., :heads_ratio].view(batch_size, num_v_heads_local)
a = mixed_ba[..., heads_ratio:].view(batch_size, num_v_heads_local)

return query, key, value, z, b, a
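Note (illustrative, toy sizes): a standalone parity check showing that the precomputed slice boundaries reproduce what the previous `torch.split` + `reshape` path returned. `reshape` is used in the sketch so it stays valid regardless of slice contiguity.

```python
import torch

# Toy stand-ins: ng = local K heads, ratio = num_v_heads // num_k_heads.
b, ng, ratio, hk, hv = 2, 4, 2, 32, 64
mixed_qkvz = torch.randn(b, ng, 2 * hk + 2 * ratio * hv)

# Previous path: torch.split along the grouped-head dim, then reshape.
q0, k0, v0, z0 = torch.split(mixed_qkvz, [hk, hk, ratio * hv, ratio * hv], dim=2)
v0 = v0.reshape(b, ng * ratio, hv)

# New path: slice with precomputed boundaries.
q_end, k_end = hk, 2 * hk
v_end = k_end + ratio * hv
q1 = mixed_qkvz[..., :q_end]
k1 = mixed_qkvz[..., q_end:k_end]
v1 = mixed_qkvz[..., k_end:v_end].reshape(b, ng * ratio, hv)

assert torch.equal(q0, q1) and torch.equal(k0, k1) and torch.equal(v0, v1)
```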

@@ -817,7 +833,6 @@ def forward_decode(
a = kwargs["a"]
b = kwargs["b"]
cache_indices = kwargs["cache_indices"]

query_start_loc = torch.arange(0,
num_decodes + 1,
device=cu_seqlens.device).to(torch.long)
@@ -831,15 +846,11 @@
conv_state_indices=cache_indices,
)

query, key, value = torch.split(
mixed_qkv,
[
self.key_dim // self.attn_tp_size,
self.key_dim // self.attn_tp_size,
self.value_dim // self.attn_tp_size,
],
dim=-1,
)
# Direct slicing instead of torch.split for better performance
key_size = self.key_dim // self.attn_tp_size
query = mixed_qkv[..., :key_size]
key = mixed_qkv[..., key_size:key_size * 2]
value = mixed_qkv[..., key_size * 2:]
# Reshape from [l, h*d] to [1, l, h, d]
seq_len = query.shape[0]
num_heads = query.shape[1] // self.head_k_dim
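Note (illustrative, made-up sizes): a self-contained sketch of the decode-path shape flow above. The flat conv output is sliced into Q/K/V at precomputed offsets, then each piece is lifted into the [1, l, h, d] layout the attention kernel expects.

```python
import torch

# Made-up stand-ins for key_dim // attn_tp_size, value_dim // attn_tp_size, head dims.
seq_len, key_size, value_size = 8, 128, 256
head_k_dim, head_v_dim = 64, 64

mixed_qkv = torch.randn(seq_len, 2 * key_size + value_size)

query = mixed_qkv[..., :key_size]
key = mixed_qkv[..., key_size:key_size * 2]
value = mixed_qkv[..., key_size * 2:]

# [l, h*d] -> [1, l, h, d] for the attention kernel.
query = query.reshape(1, seq_len, key_size // head_k_dim, head_k_dim)
key = key.reshape(1, seq_len, key_size // head_k_dim, head_k_dim)
value = value.reshape(1, seq_len, value_size // head_v_dim, head_v_dim)

assert query.shape == (1, 8, 2, 64) and value.shape == (1, 8, 4, 64)
```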
@@ -925,8 +936,7 @@ def forward_extend(
conv_states=conv_states_to_use,
has_initial_state=has_initial_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
).transpose(0, 1)
query_start_loc=query_start_loc).transpose(0, 1)

key_split_dim = self.key_dim // self.attn_tp_size
value_split_dim = self.value_dim // self.attn_tp_size
@@ -1024,9 +1034,8 @@ def forward(

projected_states_qkvz = self.in_proj_qkvz(hidden_states)
projected_states_ba = self.in_proj_ba(hidden_states)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba)

# Use fused kernel when possible to avoid elementwise ops
if self.num_v_heads // self.num_k_heads in [1, 2,
4]: # and is_cuda_graph:
mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
@@ -1060,17 +1069,11 @@ def forward(
"num_prefill": num_prefills,
"num_decode": num_decodes,
}

new_implementation = True
if new_implementation:
if num_prefills > 0:
attn_out = self.forward_extend(conv_states, ssm_states,
**kwargs)
else:
attn_out = self.forward_decode(conv_states, ssm_states,
num_decodes,
mamba_metadata.cu_seqlens,
**kwargs)
if num_prefills > 0:
attn_out = self.forward_extend(conv_states, ssm_states, **kwargs)
else:
attn_out = self.forward_decode(conv_states, ssm_states, num_decodes,
mamba_metadata.cu_seqlens, **kwargs)

z_shape_og = z.shape
# reshape input data into 2D tensor
@@ -1125,7 +1128,7 @@ def __init__(
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "1") == "0"
self.enable_fusion &= not self.enable_attention_dp

self.mapping.has_tp()
# has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()

# self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
@@ -1284,7 +1287,7 @@ def __init__(self, model_config: ModelConfig[Qwen3NextConfig],
"TRTLLM_QWEN3_EAGER_FUSION_DISABLED", "0") == "0"
self.enable_fusion &= not self.enable_attention_dp

self.mapping.has_tp()
# has_tp = self.mapping.has_tp()
has_pp = self.mapping.has_pp()

# self.fusion_config.PRE_MOE_FUSION = self.enable_fusion and has_tp
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/modules/fla/chunk.py
@@ -90,8 +90,6 @@ def forward(
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
pass

if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)