
[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout#34158

Merged
mgoin merged 11 commits into vllm-project:main from Etelis:relax-trtllm-contiguity-assertion
Mar 16, 2026

Conversation

@Etelis
Contributor

@Etelis Etelis commented Feb 9, 2026

Fixes #33572

The is_strictly_contiguous assertion on kv_cache_permute rejects views with non-canonical strides, and #33192 worked around it by disabling TRTLLM entirely when KV transfer is enabled.

I dug into the FlashInfer kernel source to check whether full-tensor contiguity is actually needed. It turns out the TRTLLM kernels in csrc/trtllm_fmha_kernel_launcher.cu read the actual tensor strides:

int kv_stride_keys_values = key_cache.stride(-2);
int kv_stride_heads       = key_cache.stride(-3);
int kv_stride_batch       = key_cache.stride(0);

So the kernel handles non-contiguous layouts fine; the assertion was a vLLM-side guard. The one place that does need contiguity is the Triton FP8 dequant kernel (_trtllm_prefill_attn_kvfp8_dequant), which computes strides from shape. Moved the assertion there.
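To illustrate the distinction (a minimal Python sketch, not vLLM code; the shapes and stride values below are made-up numbers): a kernel that derives strides from the shape assumes a packed layout, while the TRTLLM launcher reads the view's real strides, which a cross-layer allocation changes only in the outer dims.

```python
def strides_from_shape(shape):
    """Element strides a kernel would infer for a packed (contiguous) tensor."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return tuple(reversed(out))

# Hypothetical per-layer view carved out of a 2-layer cross-layer buffer:
# the outer (batch) stride spans both layers, but the inner dims stay packed.
shape = (4, 8, 16, 64)                        # (blocks, heads, block_size, head_size)
real_strides = (2 * 8 * 16 * 64, 16 * 64, 64, 1)

assert strides_from_shape(shape) == (8192, 1024, 64, 1)  # what shape alone implies
assert real_strides != strides_from_shape(shape)         # shape-derived strides would be wrong
assert real_strides[-2:] == (64, 1)                      # inner dims are still packed
```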

Testing

Ran on B200 (Blackwell, SM100) with vLLM 0.15.1, FlashInfer 0.6.1, Llama-3.2-1B-Instruct.

test_cpu_offloading[FLASHINFER-48]: PASSED
test_flashinfer_trtllm_attention: 224 passed, 16 skipped

Using TRTLLM prefill attention (auto-detected).
Using HND KV cache layout for FLASHINFER backend.
Allocating a cross layer KV cache of shape (167623, 2, 8, 16, 16, 64)

Running tests: 100% 10/10 [00:02, 3.88it/s]
Average times:
    Cold: 69.95ms
    GPU hit: 19.26ms
    CPU hit: 40.50ms
PASSED (83.24s)

…ayout

The TRTLLM FlashInfer kernels (trtllm_paged_attention_decode/context)
use actual tensor strides via key_cache.stride() rather than computing
strides from shape. This means they natively support non-contiguous KV
cache tensors, such as those produced by cross-layer KV cache layout
used for CPU KV offloading.

Remove the overly conservative is_strictly_contiguous assertion on
kv_cache_permute and the blanket TRTLLM disable when KV transfer is
enabled. Move the contiguity assertion into the FP8 dequant branch
only, where the Triton kernel genuinely computes strides from shape.

Signed-off-by: Itay Etelis <itayetelis@gmail.com>

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify mergify bot added nvidia v1 bug Something isn't working labels Feb 9, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request relaxes the contiguity assertion for the TRTLLM KV cache to allow its use with cross-layer layouts, which can result from KV cache transfer. The change correctly removes the assertion from the general prefill and decode paths and moves it to the specific FP8 dequantization kernel path where contiguity is strictly required.

My review identifies one potential issue: while the logic to disable TRTLLM for decode with KV transfer is removed, a similar logic for prefill seems to have been missed, which would prevent the fix from being fully effective. I've added a comment with a suggestion to address this.

Member

@mgoin mgoin left a comment


This seems reasonable to me given your investigation cc @vadiklyutiy @pavanimajety to confirm

Remove the guard that disabled TRTLLM for prefill when KV transfer
is enabled. The TRTLLM kernels read actual tensor strides and handle
non-contiguous layouts correctly — the contiguity assertion was already
moved to the FP8 dequant path where it is truly needed. The decode
path was fixed in the previous commit but the prefill path was missed.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify

mergify bot commented Mar 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @Etelis.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 10, 2026
Both sides removed the _kv_transfer_enabled TRTLLM disable guard -
upstream via revert of vllm-project#33192, our branch via the bugfix commit.
Trivial resolution: keep the blank line for readability.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@mergify mergify bot removed the needs-rebase label Mar 10, 2026
@vadiklyutiy
Collaborator

I introduced the is_strictly_contiguous check in #32008 and #32417.
There was a problem with degenerate strides:

Shape: torch.Size([1, 32, 128])
Stride: (4608, 128, 1)

4608 != 32*128

So the quote from the PR description:

I dug into the FlashInfer kernel source to check if full-tensor contiguity is actually needed. Turns out the TRTLLM kernels in csrc/trtllm_fmha_kernel_launcher.cu read actual tensor strides:

int kv_stride_keys_values = key_cache.stride(-2);
int kv_stride_heads       = key_cache.stride(-3);
int kv_stride_batch       = key_cache.stride(0);

fails with these degenerate strides: kv_stride_batch will be incorrect.

I want to note that is_strictly_contiguous actually checks two things:

  • is_contiguous
  • strides are canonical

If we are talking only about removing is_contiguous, I'd propose splitting is_strictly_contiguous into two parts and leaving the "strides are canonical" check in the same place.
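The proposed split can be sketched like this (illustrative Python, not vLLM's actual helpers): a tensor can be contiguous in PyTorch's sense (strides of size-1 dims are ignored by is_contiguous) yet still have non-canonical strides, which is exactly the degenerate case above.

```python
def canonical_strides(shape):
    """Strides a C-contiguous tensor of this shape would report."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return tuple(reversed(out))

def has_canonical_strides(shape, strides):
    """The 'strides are canonical' half of is_strictly_contiguous."""
    return tuple(strides) == canonical_strides(shape)

# Degenerate case from the comment: dim 0 has size 1, so PyTorch's
# is_contiguous() ignores its stride, yet 4608 != 32 * 128.
assert not has_canonical_strides((1, 32, 128), (4608, 128, 1))
assert has_canonical_strides((1, 32, 128), (4096, 128, 1))
```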

@vadiklyutiy
Collaborator

I dug into and thought about the topic a bit more.

Fully disabling TRTLLM attn when using kv-cache offloading (implemented in #33192) is a really expensive workaround: TRTLLM attn is the fastest.

This PR is the right direction, but let's fully solve the problem: modify trtllm_prefill_attn_kvfp8_dequant to work with a non-contiguous kv-cache.

@vadiklyutiy
Collaborator


I created #36867, which adds support for a non-contiguous kv-cache in trtllm_prefill_attn_kvfp8_dequant. If it's OK, we can fully remove the is_strictly_contiguous assert for the kv-cache.

@pavanimajety
Collaborator

pavanimajety commented Mar 13, 2026

@Etelis Just FYI on what happens if the strides are not canonical: flashinfer-ai/flashinfer#2232 (comment)

I vote to keep the stride check as Vadim has it in #36867.

Etelis and others added 2 commits March 13, 2026 21:15
… for TRTLLM kv_cache

Replace is_strictly_contiguous(kv_cache_permute) with a targeted check
that only requires the inner (block_size, head_size) dims to be contiguous.
This allows non-contiguous outer dims from cross-layer unified allocation
while still catching degenerate strides that break TMA descriptors
(flashinfer-ai/flashinfer#2232).

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
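The targeted check the commit describes can be sketched as follows (illustrative Python; the helper name and shapes are assumptions, not the merged code): only the innermost (block_size, head_size) dims must be packed, so non-canonical outer strides from a cross-layer allocation pass while a degenerate inner layout is still rejected.

```python
def inner_dims_packed(shape, strides):
    """Require only the (block_size, head_size) tail to be contiguous:
    stride[-1] == 1 and stride[-2] == head_size."""
    return strides[-1] == 1 and strides[-2] == shape[-1]

# Cross-layer view: the outer stride spans multiple layers -> accepted.
assert inner_dims_packed((4, 8, 16, 64), (16384, 1024, 64, 1))
# Degenerate inner stride -> rejected.
assert not inner_dims_packed((4, 8, 16, 64), (16384, 1024, 128, 1))
```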
@Etelis
Contributor Author

Etelis commented Mar 13, 2026

Thanks @vadiklyutiy — your analysis of the degenerate stride issue and the kernel fix in #36867 are great. I've updated this PR to use a targeted canonical stride check (matching your approach) instead of the blanket contiguity assertion.

Collaborator

@pavanimajety pavanimajety left a comment


Thanks @Etelis!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 13, 2026
@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 13, 2026
@Etelis Etelis requested a review from mgoin March 16, 2026 12:53
@mgoin mgoin merged commit 5ae685c into vllm-project:main Mar 16, 2026
58 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 16, 2026
@vadiklyutiy
Collaborator

@Etelis Could you clarify why you kept the checks only for the -1 and -2 dims? Shouldn't we check -3 as well?
This flashinfer-ai/flashinfer#2232 (comment) was about incorrectness of the stride[-3].

Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…ayout (vllm-project#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: GPT-OSS with CPU KV cache offload break with FlashInfer

5 participants