Support non-contiguous KV cache in TRTLLM fp8 dequant kernel#36867

Merged

pavanimajety merged 4 commits into vllm-project:main from CentML:make_dequant_noncont on Mar 17, 2026

Conversation

@vadiklyutiy
Collaborator

@vadiklyutiy vadiklyutiy commented Mar 12, 2026

Summary

Fix the trtllm_prefill_attn_kvfp8_dequant Triton kernel to support non-contiguous KV cache tensors (e.g. cross-layer unified allocation used by KV offloading).

The kernel previously computed flat pointer offsets from tensor shape, assuming contiguity. With cross-layer KV caches the per-layer view is non-contiguous (strides skip over other layers), causing incorrect memory reads. Now the kernel uses actual tensor strides for the page, K/V, and head dimensions.
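The breakage is easy to reproduce outside the kernel. Below is a minimal sketch (NumPy for brevity; the layout and dimension names are illustrative, not vLLM's actual allocation) of why shape-derived offsets read another layer's data while the view's real strides are correct:

```python
import numpy as np

# Hypothetical cross-layer ("unified") allocation: one buffer holds every
# layer, so a per-layer slice is a non-contiguous view whose strides skip
# over the other layers.
num_pages, num_layers, block_size, num_kv_heads, head_size = 4, 3, 16, 2, 8
unified = np.arange(
    num_pages * num_layers * 2 * block_size * num_kv_heads * head_size,
    dtype=np.float32,
).reshape(num_pages, num_layers, 2, block_size, num_kv_heads, head_size)

kv = unified[:, 1]                  # layer-1 view: (pages, K/V, block, heads, head)
assert not kv.flags["C_CONTIGUOUS"]

def contiguous_strides(shape):
    """Element strides a contiguity-assuming kernel derives from shape alone."""
    out, acc = [], 1
    for dim in reversed(shape):
        out.append(acc)
        acc *= dim
    return tuple(reversed(out))

shape_strides = contiguous_strides(kv.shape)                # (512, 256, 16, 8, 1)
real_strides = tuple(s // kv.itemsize for s in kv.strides)  # (1536, 256, 16, 8, 1)

idx = (2, 1, 5, 1, 3)               # (page, k_or_v, token, head, dim)
base = (kv.__array_interface__["data"][0]
        - unified.__array_interface__["data"][0]) // kv.itemsize
storage = unified.reshape(-1)

# Real strides hit the right element; shape-derived strides land in another layer.
assert storage[base + sum(i * s for i, s in zip(idx, real_strides))] == kv[idx]
assert storage[base + sum(i * s for i, s in zip(idx, shape_strides))] != kv[idx]
```

The page-dimension stride (1536 vs the shape-derived 512) is exactly the "strides skip over other layers" effect described above; passing the tensor's real strides into the Triton kernel is what this patch does.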

Related: vllm-project/vllm#34158 — relaxes the same assertion and moves the contiguity check closer to the dequant kernel. Our change goes further by fixing the dequant kernel itself to work with non-contiguous inputs.

Test

New standalone test: tests/kernels/attention/test_trtllm_kvfp8_dequant.py

python -m pytest tests/kernels/attention/test_trtllm_kvfp8_dequant.py -v

71 tests total (64 parametrized + 7 corner cases), all passing:

  • Parametrized matrix: num_kv_heads={1,8}, head_size={64,128}, block_size={16,32}, batch_size={1,4}, num_pages_per_seq={3,8}, contiguous={True,False}
  • Corner cases: zero/padding pages in block tables, all-zero block tables, different K/V scales, single page per sequence, large page indices (32K), large block_size (64), cross-layer with 36 layers (real gpt-oss-120b pattern)
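For reference, the contiguous/non-contiguous axis of the matrix can be exercised with a small helper like the one below (illustrative only; the function name and layout are not taken from the actual test file):

```python
import numpy as np

def make_layer_view(num_pages=3, num_layers=2, block_size=16,
                    num_kv_heads=1, head_size=64, layer=0, contiguous=True):
    """Hypothetical helper: return a per-layer KV view from a cross-layer
    buffer, either materialized contiguously or left as the raw
    non-contiguous view."""
    buf = np.zeros(
        (num_pages, num_layers, 2, block_size, num_kv_heads, head_size),
        dtype=np.float32,
    )
    view = buf[:, layer]
    return view.copy() if contiguous else view

assert make_layer_view(contiguous=True).flags["C_CONTIGUOUS"]
assert not make_layer_view(contiguous=False).flags["C_CONTIGUOUS"]
```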

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for non-contiguous KV cache tensors in the trtllm_prefill_attn_kvfp8_dequant Triton kernel. The kernel is updated to use tensor strides for memory access, correctly handling layouts like those from cross-layer unified allocation. A comprehensive new test suite is also added, which thoroughly validates the changes against a reference implementation for both contiguous and non-contiguous cases, including numerous corner cases. The implementation appears correct and robust, and the tests provide strong confidence in the fix.

@pavanimajety
Collaborator

@Etelis What do you think of this fix? Seems reasonable to me to remove the stride check for kv_cache throughout.

@Etelis
Contributor

Etelis commented Mar 13, 2026

Looks good to me — using actual strides in the dequant kernel is the right fix. With this merged I can drop the is_strictly_contiguous assertion I kept for the FP8 path in #34158 and replace it with the same minimal stride check you have here.
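A minimal stride check of the kind mentioned here could look roughly as follows (sketched with NumPy arrays for brevity, though vLLM holds torch tensors; the function name is made up): with the kernel now consuming real strides, only the innermost head_size dimension must be unit-stride, so cross-layer views pass while genuinely incompatible layouts are still rejected.

```python
import numpy as np

def kernel_compatible(kv: np.ndarray) -> bool:
    # Hypothetical minimal check: only the innermost (head_size) dimension
    # must be unit-stride; page/KV/head strides may be arbitrary because
    # the dequant kernel uses the tensor's real strides for those dims.
    return kv.strides[-1] == kv.itemsize

buf = np.zeros((4, 3, 2, 16, 2, 8), dtype=np.float32)
assert kernel_compatible(buf[:, 1])          # non-contiguous layer view passes
assert not kernel_compatible(buf[..., ::2])  # strided head dim is rejected
```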

@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 13, 2026
Collaborator

@pavanimajety pavanimajety left a comment


LGTM, thanks!

vadiklyutiy and others added 2 commits March 14, 2026 17:23
Co-authored-by: Pavani Majety <pavanimajety@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
@mergify

mergify bot commented Mar 14, 2026

Hi @vadiklyutiy, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy
Collaborator Author

vadiklyutiy commented Mar 16, 2026

If there are no further comments, could we merge this PR?

@pavanimajety pavanimajety merged commit 6c1cfba into vllm-project:main Mar 17, 2026
58 of 59 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 17, 2026
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
andylolu2 pushed a commit to andylolu2/vllm that referenced this pull request Mar 18, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026

Labels

nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

Status: Done

3 participants