[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout #34158
mgoin merged 11 commits into vllm-project:main
Conversation
Relax TRTLLM KV cache contiguity assertion for cross-layer layout

The TRTLLM FlashInfer kernels (trtllm_paged_attention_decode/context) use actual tensor strides via key_cache.stride() rather than computing strides from shape. This means they natively support non-contiguous KV cache tensors, such as those produced by the cross-layer KV cache layout used for CPU KV offloading.

Remove the overly conservative is_strictly_contiguous assertion on kv_cache_permute and the blanket TRTLLM disable when KV transfer is enabled. Move the contiguity assertion into the FP8 dequant branch only, where the Triton kernel genuinely computes strides from shape.

Signed-off-by: Itay Etelis <itayetelis@gmail.com>
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
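The distinction the commit message draws can be illustrated with plain PyTorch (a hypothetical sketch, not vLLM code): strides read via `.stride()` describe a permuted view exactly, while strides recomputed from `.shape` silently assume a contiguous layout.

```python
import torch

# Hypothetical 4-D tensor standing in for a KV cache block (not vLLM's
# actual layout). Permuting it yields a non-contiguous view.
x = torch.zeros(2, 3, 4, 5)
view = x.permute(1, 0, 2, 3)      # shape (3, 2, 4, 5), non-contiguous

# Strides a kernel reads via tensor.stride(): the true memory layout.
actual = view.stride()            # (20, 60, 5, 1)

# Strides naively recomputed from shape, assuming contiguity.
def contiguous_strides(shape):
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)
        acc *= dim
    return tuple(strides)

assumed = contiguous_strides(view.shape)   # (40, 20, 5, 1): wrong for this view

print(actual, assumed, view.is_contiguous())
```

A kernel that indexes with `actual` reads the right elements of the permuted view; one that indexes with `assumed` does not, which is exactly why the Triton FP8 dequant path still needs the contiguity check.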
Code Review
This pull request relaxes the contiguity assertion for the TRTLLM KV cache to allow its use with cross-layer layouts, which can result from KV cache transfer. The change correctly removes the assertion from the general prefill and decode paths and moves it to the specific FP8 dequantization kernel path where contiguity is strictly required.
My review identifies one potential issue: while the logic to disable TRTLLM for decode with KV transfer is removed, similar logic for prefill seems to have been missed, which would prevent the fix from being fully effective. I've added a comment with a suggestion to address this.
mgoin left a comment
This seems reasonable to me given your investigation cc @vadiklyutiy @pavanimajety to confirm
Remove the guard that disabled TRTLLM for prefill when KV transfer is enabled. The TRTLLM kernels read actual tensor strides and handle non-contiguous layouts correctly; the contiguity assertion was already moved to the FP8 dequant path where it is truly needed. The decode path was fixed in the previous commit, but the prefill path was missed.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Both sides removed the _kv_transfer_enabled TRTLLM disable guard: upstream via the revert of vllm-project#33192, our branch via the bugfix commit. Trivial resolution: keep the blank line for readability.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
I introduced the TRTLLM disable guard in #33192. So, the quote from the PR description that the kernel handles non-contiguous layouts fine is not entirely accurate: the kernel can fail with degenerate strides. Want to note that if we are talking only about removing the assertion, those degenerate layouts remain a problem.
I did dig/think a bit more about the topic. Fully disabling TRTLLM attn when using KV cache offloading (implemented in #33192) is a really expensive workaround: TRTLLM attn is the fastest. This PR is the right direction, but let's fully solve the problem and modify the kernel itself.
I created #36867, which adds support for non-contiguous KV cache.
@Etelis Just FYI on what happens if the strides are not canonical: flashinfer-ai/flashinfer#2232 (comment). I vote to keep the stride check like Vadim has in #36867.
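The linked flashinfer issue concerns layouts whose strides are not merely non-contiguous but degenerate. A minimal sketch of how such a layout arises in PyTorch (a hypothetical example, not taken from the PR): `expand()` broadcasts a size-1 dim by giving it stride 0, so distinct indices alias the same memory, which is exactly the kind of layout a canonical-stride check rejects.

```python
import torch

# Hypothetical illustration of a "degenerate" stride: expand() broadcasts
# a size-1 dim by setting its stride to 0, so element (0, j, k) and
# element (2, j, k) point at the same memory.
base = torch.zeros(1, 4, 8)
degenerate = base.expand(3, 4, 8)    # shape (3, 4, 8), stride (0, 8, 1)

print(degenerate.stride())           # (0, 8, 1)
print(degenerate.is_contiguous())    # False
```

A kernel that builds TMA descriptors from such strides can misbehave even though the tensor's shape looks ordinary, which is the motivation for keeping a stride check rather than dropping it entirely.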
… for TRTLLM kv_cache

Replace is_strictly_contiguous(kv_cache_permute) with a targeted check that only requires the inner (block_size, head_size) dims to be contiguous. This allows non-contiguous outer dims from cross-layer unified allocation while still catching degenerate strides that break TMA descriptors (flashinfer-ai/flashinfer#2232).

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
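The targeted check described in this commit can be sketched as follows (a hypothetical helper, not the exact vLLM code): it requires canonical strides only for the two innermost dims and tolerates strided outer dims.

```python
import torch

def inner_dims_contiguous(kv: torch.Tensor) -> bool:
    """Hypothetical sketch of the targeted check: the two innermost dims
    (block_size, head_size) must have canonical strides; outer dims may
    carry arbitrary strides from a cross-layer unified allocation."""
    return kv.stride(-1) == 1 and kv.stride(-2) == kv.size(-1)

# Fully contiguous cache: passes.
cache = torch.zeros(10, 2, 16, 64)
assert inner_dims_contiguous(cache)

# Cross-layer style view: slicing a unified buffer leaves the outer dim
# non-contiguous, but the inner dims stay canonical, so it passes.
unified = torch.zeros(4, 10, 2, 16, 64)   # hypothetical unified allocation
layer_view = unified[:, 1]                # strides (20480, 1024, 64, 1)
assert not layer_view.is_contiguous()
assert inner_dims_contiguous(layer_view)

# Degenerate inner stride from expand(): rejected.
bad = torch.zeros(10, 2, 16, 1).expand(10, 2, 16, 64)
assert not inner_dims_contiguous(bad)
```

The follow-up question in this thread (whether dim -3 should be checked too) is about how far this canonical-prefix requirement should extend; the sketch above mirrors the commit's choice of checking only -1 and -2.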
Thanks @vadiklyutiy, your analysis of the degenerate stride issue and the kernel fix in #36867 are great. I've updated this PR to use a targeted canonical stride check (matching your approach) instead of the blanket contiguity assertion.
@Etelis Could you clarify why you kept checks only for the -1 and -2 dims? Shouldn't we check -3 as well?
[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Fixes #33572
The `is_strictly_contiguous` assertion on `kv_cache_permute` rejects views with non-canonical strides, and #33192 worked around it by just disabling TRTLLM entirely when KV transfer is enabled.

I dug into the FlashInfer kernel source to check whether full-tensor contiguity is actually needed. Turns out the TRTLLM kernels in `csrc/trtllm_fmha_kernel_launcher.cu` read the actual tensor strides. So the kernel handles non-contiguous layouts fine; the assertion was a vLLM-side guard, not a kernel requirement.
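A minimal sketch of the situation, with hypothetical shapes rather than the actual vLLM cache layout: slicing one layer out of a unified cross-layer allocation yields a view that fails a blanket contiguity check even though its strides fully describe the layout.

```python
import torch

# Hypothetical cross-layer allocation: all layers share one buffer and
# each layer's KV cache is a strided view into it.
num_blocks, num_layers, block_size, num_heads, head_dim = 8, 4, 16, 2, 64
unified = torch.zeros(num_blocks, num_layers, 2, block_size, num_heads, head_dim)

layer0_kv = unified[:, 0]   # per-layer view, shape (8, 2, 16, 2, 64)

# The view is not contiguous, so a blanket contiguity assertion rejects it...
print(layer0_kv.is_contiguous())   # False
# ...but its strides describe the layout exactly, which is all a kernel
# that reads tensor.stride() (as the TRTLLM launcher does) needs.
print(layer0_kv.stride())          # (16384, 2048, 128, 64, 1)
```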
The one place that does need contiguity is the Triton FP8 dequant kernel (`_trtllm_prefill_attn_kvfp8_dequant`), which computes strides from shape. Moved the assertion there.

Testing
Ran on B200 (Blackwell, SM100) with vLLM 0.15.1, FlashInfer 0.6.1, Llama-3.2-1B-Instruct.
test_cpu_offloading[FLASHINFER-48]: PASSED
test_flashinfer_trtllm_attention: 224 passed, 16 skipped