
Commit de6abc9 (parent 9a4f606)

[https://nvbugs/5483615][fix] Remove unnecessary assertion to let main model and MTP work at the same time

Signed-off-by: Jin Li <[email protected]>
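
The change below drops the `padded_num_tokens is not None` guard and its shape assertions in `_attn_impl`, and the matching guard in `forward_impl`, slicing the inputs to `num_tokens` unconditionally instead. A minimal sketch of why that is safe (standalone PyTorch, not TensorRT-LLM code; the sizes are made up):

```python
import torch

num_tokens = 6

# Unpadded path (e.g. an eager run) and a padded path (e.g. a batch
# padded to a fixed size for CUDA graph capture).
q_unpadded = torch.randn(num_tokens, 64)
q_padded = torch.randn(num_tokens + 2, 64)

# Slicing to num_tokens is a no-op on unpadded tensors and trims the
# padding rows otherwise, so no guard or shape assertion is needed.
assert q_unpadded[:num_tokens, :].shape[0] == num_tokens
assert q_padded[:num_tokens, :].shape[0] == num_tokens
```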

2 files changed: +8 −17 lines
tensorrt_llm/_torch/modules/attention.py (8 additions, 16 deletions)
@@ -338,19 +338,13 @@ def _attn_impl(
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
     ):
-
-        padded_num_tokens = attn_metadata.padded_num_tokens
         num_tokens = attn_metadata.num_tokens
 
-        if padded_num_tokens is not None:
-            assert q.shape[0] == padded_num_tokens
-            q = q[:num_tokens, :]
-            if k is not None:
-                assert k.shape[0] == padded_num_tokens
-                k = k[:num_tokens, :]
-            if v is not None:
-                assert v.shape[0] == padded_num_tokens
-                v = v[:num_tokens, :]
+        q = q[:num_tokens, :]
+        if k is not None:
+            k = k[:num_tokens, :]
+        if v is not None:
+            v = v[:num_tokens, :]
 
         out_scale = None
         out_scale_sf = None
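
The commit message attributes the failure to running the main model and MTP at the same time; the exact trigger is not shown here, but one plausible mode (an assumption, not taken from the sources) is a tensor that reaches `_attn_impl` already trimmed to `num_tokens` while `attn_metadata.padded_num_tokens` is still set:

```python
import torch

# Hypothetical repro of the removed assertion; shapes and scenario assumed.
padded_num_tokens, num_tokens = 8, 6
q = torch.randn(num_tokens, 64)  # already trimmed by an earlier module

# The old guard (removed by this commit) would raise here:
#     assert q.shape[0] == padded_num_tokens   # fails: 6 != 8
# The new behavior slices unconditionally, a no-op on trimmed tensors.
q = q[:num_tokens, :]
assert q.shape[0] == num_tokens
```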
@@ -954,12 +948,10 @@ def forward_impl(self, position_ids: Optional[torch.Tensor],
         num_generations = attn_metadata.num_generations
         num_ctx_tokens = attn_metadata.num_ctx_tokens
         num_tokens = attn_metadata.num_tokens
-        padded_num_tokens = attn_metadata.padded_num_tokens
 
-        if padded_num_tokens is not None:
-            hidden_states = hidden_states[:num_tokens, ...]
-            if position_ids is not None:
-                position_ids = position_ids[:num_tokens, ...]
+        hidden_states = hidden_states[:num_tokens, ...]
+        if position_ids is not None:
+            position_ids = position_ids[:num_tokens, ...]
 
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
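
`forward_impl` applies the same pattern with an ellipsis index, which trims the leading token dimension whatever the trailing rank. A small standalone check (tensor shapes assumed, not taken from the sources):

```python
import torch

num_tokens = 4
hidden_states = torch.randn(7, 128)  # 7 = padded token count (assumed)
position_ids = torch.arange(7)       # 1-D position ids (assumed)

# `[:num_tokens, ...]` trims dim 0 and leaves any trailing dims intact,
# so the same form serves both the 2-D and the 1-D tensor.
hidden_states = hidden_states[:num_tokens, ...]
position_ids = position_ids[:num_tokens, ...]

assert hidden_states.shape == (num_tokens, 128)
assert position_ids.shape == (num_tokens,)
```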

tests/integration/test_lists/waives.txt (0 additions, 1 deletion)
@@ -336,7 +336,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-M
 test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Scout-17B-16E-Instruct-FP8-llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8-True] SKIP (https://nvbugs/5481094)
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct] SKIP (https://nvbugs/5480415)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=False] SKIP (https://nvbugs/5483534)
-accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5483615)
 accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend SKIP (https://nvbugs/5448748)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend SKIP (https://nvbugs/5448748)
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] SKIP (https://nvbugs/5483913)
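
Each waives.txt entry pairs a pytest node id with a `SKIP (bug URL)` annotation. Deleting the line for https://nvbugs/5483615 re-enables `accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True]` now that this commit addresses the underlying bug.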
