 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.quantization.utils import fp4_utils
 
 from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                  FlashInferAttentionMetadata, TrtllmAttention,
@@ -336,6 +337,21 @@ def _attn_impl(
         attention_sinks: Optional[torch.Tensor] = None,
     ):
 
+        padded_num_tokens = attn_metadata.padded_num_tokens
+        num_tokens = attn_metadata.num_tokens
+
+        if padded_num_tokens is not None:
+            assert q.shape[0] == padded_num_tokens
+            q = q[:num_tokens, :]
+            if k is not None:
+                assert k.shape[0] == padded_num_tokens
+                k = k[:num_tokens, :]
+            if v is not None:
+                assert v.shape[0] == padded_num_tokens
+                v = v[:num_tokens, :]
+            assert output is not None
+            assert output_sf is None
+
         out_scale = None
         out_scale_sf = None
         has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -366,14 +382,19 @@ def _attn_impl(
             attention_window_size=attention_window_size,
             attention_mask_data=attention_mask_data,
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
-            output=output,
+            output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
             attention_sinks=attention_sinks)
         if isinstance(attn_output, tuple):
             assert len(
                 attn_output
             ) == 2, "attn_output should be a tuple of (output, output_sf)"
             return attn_output[0], attn_output[1]
+        if output is not None and output.shape[0] != num_tokens:
+            output[num_tokens:].fill_(0)
+        if output_sf is not None and output_sf.shape[0] != fp4_utils.pad_up(
+                num_tokens, 128):
+            output_sf[fp4_utils.pad_up(num_tokens, 128):].fill_(0)
         return attn_output, None
 
     def forward(
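
Note on the hunk above: when padded batches are in play, `_attn_impl` now runs the kernel on the first `num_tokens` rows and zero-fills the padded tail of the caller-provided output buffer. A minimal standalone sketch of that truncate-then-zero pattern; the helper `run_padded` and the `fake_attention` stand-in are hypothetical, not TensorRT-LLM APIs:

import torch

def run_padded(q, output, num_tokens, fake_attention):
    # q and output are allocated with padded_num_tokens rows (e.g. for CUDA graphs)
    valid_q = q[:num_tokens, :]                           # drop the padding rows
    fake_attention(valid_q, out=output[:num_tokens, :])   # kernel writes only the valid rows
    output[num_tokens:].fill_(0)                          # keep the padded tail deterministic
    return output

q = torch.randn(16, 8)                                    # batch padded to 16 tokens
out = torch.empty(16, 8)
run_padded(q, out, num_tokens=13,
           fake_attention=lambda x, out: out.copy_(x))    # trivial stand-in kernel
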
@@ -908,11 +929,10 @@ def create_output(self, hidden_states: torch.Tensor):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)
 
-    def forward_impl(self,
-                     position_ids: Optional[torch.Tensor],
+    def forward_impl(self, position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
                      attn_metadata: AttentionMetadata,
-                     output: Optional[torch.Tensor] = None) -> torch.Tensor:
+                     output: torch.Tensor) -> None:
         """
         Forward pass for the MLA module.
 
@@ -925,6 +945,18 @@ def forward_impl(self,
         Returns:
             torch.Tensor: The output tensor.
         """
+        # split q, k, v into context and gen batches
+        num_contexts = attn_metadata.num_contexts
+        num_generations = attn_metadata.num_generations
+        num_ctx_tokens = attn_metadata.num_ctx_tokens
+        num_tokens = attn_metadata.num_tokens
+        padded_num_tokens = attn_metadata.padded_num_tokens
+
+        if padded_num_tokens is not None:
+            hidden_states = hidden_states[:num_tokens, ...]
+            if position_ids is not None:
+                position_ids = position_ids[:num_tokens, ...]
+
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -952,15 +984,11 @@ def forward_impl(self,
             self.aux_stream,
         )
 
-        # split q, k, v into context and gen batches
-        num_contexts = attn_metadata.num_contexts
-        num_generations = attn_metadata.num_generations
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        num_tokens = attn_metadata.num_tokens
-
         assert q.shape[
             0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
 
+        assert output is not None, "output must be provided"
+
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -970,17 +998,14 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
 
-            attn_output_context = self.forward_context(
+            self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
                 k_pe_ctx,
                 attn_metadata,
+                output[:num_ctx_tokens, :],
                 latent_cache_ctx,
-                output=output if num_generations == 0 else None)
-            if num_generations == 0:
-                return attn_output_context
-            else:
-                attn_output_context = None
+            )
 
         if num_generations > 0:
             q_gen = q[num_ctx_tokens:, ...]
@@ -991,39 +1016,17 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
 
-            attn_output_gen = self.forward_generation(
+            self.forward_generation(
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
                 attn_metadata,
+                output[num_ctx_tokens:num_tokens, :],
                 latent_cache_gen,
-                output=output if num_contexts == 0 else None)
-            if num_contexts == 0:
-                return attn_output_gen
-            else:
-                attn_output_gen = None
+            )
 
-        # release pytorch activation memory
-        q = None
-        compressed_kv = None
-        k_pe = None
-
-        assert attn_output_context is not None and attn_output_gen is not None
-        assert (
-            len(attn_output_context.shape) == 2
-        ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
-        assert (
-            len(attn_output_gen.shape) == 2
-        ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
-        output = output if output is not None else torch.empty(
-            (num_tokens, attn_output_context.shape[1]),
-            dtype=attn_output_context.dtype,
-            device=attn_output_context.device)
-        output[:attn_output_context.shape[0], :] = attn_output_context
-        output[attn_output_context.shape[0]:, :] = attn_output_gen
-        attn_output_context = None
-        attn_output_gen = None
-        return output
+        if padded_num_tokens is not None:
+            output[num_tokens:].fill_(0)
 
     def _maybe_concat_qkv(self, q, k, v):
         if k is not None and v is not None and self.support_fused_qkv:
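
Note on the hunk above: `forward_impl` no longer allocates per-branch temporaries and concatenates them; context and generation results are written straight into disjoint slices of the preallocated `output`. A rough sketch of this write-into-slices pattern, where `split_and_write`, `ctx_fn`, and `gen_fn` are hypothetical stand-ins for the real `forward_context`/`forward_generation` calls:

import torch

def split_and_write(q, output, num_ctx_tokens, num_tokens, ctx_fn, gen_fn):
    # context tokens come first, generation tokens follow, padding (if any) is last
    if num_ctx_tokens > 0:
        ctx_fn(q[:num_ctx_tokens], output[:num_ctx_tokens, :])
    if num_tokens > num_ctx_tokens:
        gen_fn(q[num_ctx_tokens:num_tokens], output[num_ctx_tokens:num_tokens, :])
    output[num_tokens:].fill_(0)  # zero any padding rows beyond the valid tokens

q = torch.randn(10, 4)
out = torch.empty(10, 4)
split_and_write(q, out, num_ctx_tokens=6, num_tokens=9,
                ctx_fn=lambda x, o: o.copy_(x), gen_fn=lambda x, o: o.copy_(x))
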
@@ -1032,13 +1035,14 @@ def _maybe_concat_qkv(self, q, k, v):
         return q, k, v
 
     def forward_context_default(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kv = self.kv_b_proj(compressed_kv)
         k_nope, v = kv.split(
             [
@@ -1080,7 +1084,7 @@ def forward_context_with_cached_kv(
         q: torch.Tensor,
         latent_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         assert latent_cache is not None
         trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1166,7 +1170,7 @@ def forward_context_with_chunked_prefill(
         latent_cache: torch.
         Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
         attn_metadata: TrtllmAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         trtllm_attention = cast(TrtllmAttention, self.mha)
         # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1189,11 +1193,8 @@ def forward_context_with_chunked_prefill(
             dtype=torch.float,
             device='cuda',
         )
-        if output is None:
-            attn_output = q.new_empty(
-                (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
-        else:
-            attn_output = output
+
+        attn_output = output
         temp_attn_output = q.new_empty(
             (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
 
@@ -1325,8 +1326,8 @@ def forward_context(
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
-        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if isinstance(self.mha, TrtllmAttention):
             assert isinstance(attn_metadata, TrtllmAttentionMetadata)
@@ -1339,16 +1340,17 @@ def forward_context(
                 return self.forward_context_with_cached_kv(
                     q, latent_cache, attn_metadata, output)
         return self.forward_context_default(q, compressed_kv, k_pe,
-                                            attn_metadata, latent_cache, output)
+                                            attn_metadata, output, latent_cache)
 
     def forward_generation(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         num_tokens = q.shape[0]
         q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1420,12 +1422,6 @@ def forward_generation(
         attn_out_latent = attn_out_latent.view(
             [-1, self.num_heads, self.kv_lora_rank])
 
-        # [seq, num_heads * v_head_dim]
-        output = output if output is not None else torch.empty(
-            [num_tokens, self.num_heads * self.v_head_dim],
-            dtype=attn_out_latent.dtype,
-            device=attn_out_latent.device)
-
         attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
 
         if self.v_b_proj.dtype == torch.bfloat16:
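
For readers unfamiliar with `fp4_utils.pad_up` used in `_attn_impl` above: the NVFP4 output scale-factor tensor is padded to a multiple of 128 rows, so the valid region ends at `pad_up(num_tokens, 128)` and everything past it is zeroed. A sketch of the round-up arithmetic this assumes; illustrative only, the real helper lives in `tensorrt_llm.quantization.utils.fp4_utils`:

def pad_up(x: int, y: int) -> int:
    # round x up to the next multiple of y
    return ((x + y - 1) // y) * y

assert pad_up(13, 128) == 128
assert pad_up(128, 128) == 128
assert pad_up(129, 128) == 256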