@@ -338,6 +338,19 @@ def _attn_impl(
         attention_sinks: Optional[torch.Tensor] = None,
     ):
 
+        padded_num_tokens = attn_metadata.padded_num_tokens
+        num_tokens = attn_metadata.num_tokens
+
+        if padded_num_tokens is not None:
+            assert q.shape[0] == padded_num_tokens
+            q = q[:num_tokens, :]
+            if k is not None:
+                assert k.shape[0] == padded_num_tokens
+                k = k[:num_tokens, :]
+            if v is not None:
+                assert v.shape[0] == padded_num_tokens
+                v = v[:num_tokens, :]
+
         out_scale = None
         out_scale_sf = None
         has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -368,7 +381,7 @@ def _attn_impl(
             attention_window_size=attention_window_size,
             attention_mask_data=attention_mask_data,
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
-            output=output,
+            output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
             attention_sinks=attention_sinks)
         if isinstance(attn_output, tuple):
@@ -936,11 +949,10 @@ def create_output(self, hidden_states: torch.Tensor):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)
 
-    def forward_impl(self,
-                     position_ids: Optional[torch.Tensor],
+    def forward_impl(self, position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
                      attn_metadata: AttentionMetadata,
-                     output: Optional[torch.Tensor] = None) -> torch.Tensor:
+                     output: torch.Tensor) -> None:
         """
         Forward pass for the MLA module.
 
@@ -953,6 +965,18 @@ def forward_impl(self,
         Returns:
             torch.Tensor: The output tensor.
         """
+        # split q, k, v into context and gen batches
+        num_contexts = attn_metadata.num_contexts
+        num_generations = attn_metadata.num_generations
+        num_ctx_tokens = attn_metadata.num_ctx_tokens
+        num_tokens = attn_metadata.num_tokens
+        padded_num_tokens = attn_metadata.padded_num_tokens
+
+        if padded_num_tokens is not None:
+            hidden_states = hidden_states[:num_tokens, ...]
+            if position_ids is not None:
+                position_ids = position_ids[:num_tokens, ...]
+
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -980,15 +1004,11 @@ def forward_impl(self,
             self.aux_stream,
         )
 
-        # split q, k, v into context and gen batches
-        num_contexts = attn_metadata.num_contexts
-        num_generations = attn_metadata.num_generations
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        num_tokens = attn_metadata.num_tokens
-
         assert q.shape[
             0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"
 
+        assert output is not None, "output must be provided"
+
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -998,17 +1018,14 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
 
-            attn_output_context = self.forward_context(
+            self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
                 k_pe_ctx,
                 attn_metadata,
+                output[:num_ctx_tokens, :],
                 latent_cache_ctx,
-                output=output if num_generations == 0 else None)
-            if num_generations == 0:
-                return attn_output_context
-            else:
-                attn_output_context = None
+            )
 
         if num_generations > 0:
             q_gen = q[num_ctx_tokens:, ...]
@@ -1019,48 +1036,24 @@ def forward_impl(self,
                 assert position_ids is not None
                 k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
 
-            attn_output_gen = self.forward_generation(
+            self.forward_generation(
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
                 attn_metadata,
+                output[num_ctx_tokens:num_tokens, :],
                 latent_cache_gen,
-                output=output if num_contexts == 0 else None)
-            if num_contexts == 0:
-                return attn_output_gen
-            else:
-                attn_output_gen = None
-
-        # release pytorch activation memory
-        q = None
-        compressed_kv = None
-        k_pe = None
-
-        assert attn_output_context is not None and attn_output_gen is not None
-        assert (
-            len(attn_output_context.shape) == 2
-        ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
-        assert (
-            len(attn_output_gen.shape) == 2
-        ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
-        output = output if output is not None else torch.empty(
-            (num_tokens, attn_output_context.shape[1]),
-            dtype=attn_output_context.dtype,
-            device=attn_output_context.device)
-        output[:attn_output_context.shape[0], :] = attn_output_context
-        output[attn_output_context.shape[0]:, :] = attn_output_gen
-        attn_output_context = None
-        attn_output_gen = None
-        return output
+            )
 
     def forward_context_default(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kv = self.kv_b_proj(compressed_kv)
         k_nope, v = kv.split(
             [
@@ -1099,7 +1092,7 @@ def forward_context_with_cached_kv(
         q: torch.Tensor,
         latent_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         assert latent_cache is not None
         trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1168,7 +1161,7 @@ def forward_context_with_chunked_prefill(
         latent_cache: torch.
         Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
         attn_metadata: TrtllmAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         trtllm_attention = cast(TrtllmAttention, self.mha)
         # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1190,11 +1183,8 @@ def forward_context_with_chunked_prefill(
             dtype=torch.float,
             device='cuda',
         )
-        if output is None:
-            attn_output = q.new_empty(
-                (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
-        else:
-            attn_output = output
+
+        attn_output = output
         temp_attn_output = q.new_empty(
             (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
 
@@ -1332,8 +1322,8 @@ def forward_context(
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
-        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if isinstance(self.mha, TrtllmAttention):
             assert isinstance(attn_metadata, TrtllmAttentionMetadata)
@@ -1346,16 +1336,17 @@ def forward_context(
             return self.forward_context_with_cached_kv(
                 q, latent_cache, attn_metadata, output)
         return self.forward_context_default(q, compressed_kv, k_pe,
-                                            attn_metadata, latent_cache, output)
+                                            attn_metadata, output, latent_cache)
 
     def forward_generation(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         num_tokens = q.shape[0]
         q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1427,12 +1418,6 @@ def forward_generation(
         attn_out_latent = attn_out_latent.view(
             [-1, self.num_heads, self.kv_lora_rank])
 
-        # [seq, num_heads * v_head_dim]
-        output = output if output is not None else torch.empty(
-            [num_tokens, self.num_heads * self.v_head_dim],
-            dtype=attn_out_latent.dtype,
-            device=attn_out_latent.device)
-
         attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])
 
         if self.v_b_proj.dtype == torch.bfloat16:
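Taken together, these hunks move the MLA forward path to an in-place output convention: forward_impl no longer returns a tensor, and forward_context / forward_generation write their results into caller-provided slices of a preallocated buffer, with padded rows (attn_metadata.padded_num_tokens) sliced off before the projections run. A minimal usage sketch under that reading; the names `mla`, `hidden_states`, `position_ids`, and `attn_metadata` are placeholders, not identifiers introduced by this diff:

# Sketch only: the caller preallocates the output buffer and forward_impl fills it in place.
output = mla.create_output(hidden_states)  # [num_tokens, hidden_size], same dtype/device as hidden_states
mla.forward_impl(position_ids, hidden_states, attn_metadata, output)
# Context rows are written to output[:attn_metadata.num_ctx_tokens, :], generation rows to
# output[num_ctx_tokens:num_tokens, :]; rows beyond attn_metadata.num_tokens (padding, when
# padded_num_tokens is set) are left untouched.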