Merged
24 changes: 8 additions & 16 deletions tensorrt_llm/_torch/modules/attention.py
@@ -338,19 +338,13 @@ def _attn_impl(
         output_sf: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
     ):
-
-        padded_num_tokens = attn_metadata.padded_num_tokens
         num_tokens = attn_metadata.num_tokens
 
-        if padded_num_tokens is not None:
-            assert q.shape[0] == padded_num_tokens
-            q = q[:num_tokens, :]
-            if k is not None:
-                assert k.shape[0] == padded_num_tokens
-                k = k[:num_tokens, :]
-            if v is not None:
-                assert v.shape[0] == padded_num_tokens
-                v = v[:num_tokens, :]
+        q = q[:num_tokens, :]
+        if k is not None:
+            k = k[:num_tokens, :]
+        if v is not None:
+            v = v[:num_tokens, :]
 
         out_scale = None
         out_scale_sf = None
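
Note on the simplification above (a reviewer-style sketch, not part of the diff): when `attn_metadata.padded_num_tokens` is `None`, `q.shape[0]` should already equal `num_tokens`, so the now-unconditional `q[:num_tokens, :]` degenerates to a zero-cost view and the old branch plus asserts become unnecessary. The tensor names and sizes below are illustrative only:

```python
import torch

num_tokens, padded_num_tokens, hidden = 10, 16, 64  # hypothetical sizes

# Padded case: the slice trims the pad rows down to the real tokens.
q_padded = torch.randn(padded_num_tokens, hidden)
assert q_padded[:num_tokens, :].shape == (num_tokens, hidden)

# Unpadded case: the same slice is a no-op view sharing the original
# storage, so applying it unconditionally costs nothing.
q_unpadded = torch.randn(num_tokens, hidden)
sliced = q_unpadded[:num_tokens, :]
assert sliced.data_ptr() == q_unpadded.data_ptr()
assert sliced.shape == (num_tokens, hidden)
```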
@@ -954,12 +948,10 @@ def forward_impl(self, position_ids: Optional[torch.Tensor],
         num_generations = attn_metadata.num_generations
         num_ctx_tokens = attn_metadata.num_ctx_tokens
         num_tokens = attn_metadata.num_tokens
-        padded_num_tokens = attn_metadata.padded_num_tokens
 
-        if padded_num_tokens is not None:
-            hidden_states = hidden_states[:num_tokens, ...]
-            if position_ids is not None:
-                position_ids = position_ids[:num_tokens, ...]
+        hidden_states = hidden_states[:num_tokens, ...]
+        if position_ids is not None:
+            position_ids = position_ids[..., :num_tokens]
 
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
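The `position_ids` slice in this hunk also changes axis, from `[:num_tokens, ...]` to `[..., :num_tokens]`. A minimal sketch of why that matters, assuming `position_ids` reaches `forward_impl` with the token dimension last (e.g. shape `[1, seq_len]`), which is an assumption on my part, not verified against the call sites:

```python
import torch

num_tokens, padded_len = 10, 16  # hypothetical sizes
position_ids = torch.arange(padded_len).unsqueeze(0)  # shape [1, 16]

# Slicing dim 0 (size 1) leaves the padded positions untouched:
assert position_ids[:num_tokens, ...].shape == (1, padded_len)

# Slicing the last dim actually drops the padding:
assert position_ids[..., :num_tokens].shape == (1, num_tokens)
```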
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -336,7 +336,6 @@ test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-M
 test_e2e.py::test_ptp_quickstart_advanced_8gpus_chunked_prefill_sq_22k[Llama-4-Scout-17B-16E-Instruct-FP8-llama4-models/Llama-4-Scout-17B-16E-Instruct-FP8-True] SKIP (https://nvbugs/5481094)
 test_e2e.py::test_ptp_quickstart_multimodal_2gpu[Phi-4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct] SKIP (https://nvbugs/5480415)
 accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=False] SKIP (https://nvbugs/5483534)
-accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5483615)
 accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend SKIP (https://nvbugs/5448748)
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend SKIP (https://nvbugs/5448748)
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_w4a8_mxfp4[fp8-latency-CUTLASS] SKIP (https://nvbugs/5483913)