Support non-contiguous KV cache in TRTLLM fp8 dequant kernel #36867
pavanimajety merged 4 commits into vllm-project:main
Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Code Review
This pull request introduces support for non-contiguous KV cache tensors in the trtllm_prefill_attn_kvfp8_dequant Triton kernel. The kernel is updated to use tensor strides for memory access, correctly handling layouts like those from cross-layer unified allocation. A comprehensive new test suite is also added, which thoroughly validates the changes against a reference implementation for both contiguous and non-contiguous cases, including numerous corner cases. The implementation appears correct and robust, and the tests provide strong confidence in the fix.
@Etelis What do you think of this fix? Seems reasonable to me to remove the stride check for kv_cache throughout.
Looks good to me — using actual strides in the dequant kernel is the right fix. With this merged I can drop the
Co-authored-by: Pavani Majety <pavanimajety@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Hi @vadiklyutiy, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
If there are no further comments, could we merge this PR?
…oject#36867) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Pavani Majety <pavanimajety@gmail.com>
Summary
Fix the trtllm_prefill_attn_kvfp8_dequant Triton kernel to support non-contiguous KV cache tensors (e.g. the cross-layer unified allocation used by KV offloading). The kernel previously computed flat pointer offsets from the tensor shape, assuming contiguity. With cross-layer KV caches the per-layer view is non-contiguous (its strides skip over other layers' data), causing incorrect memory reads. The kernel now uses the actual tensor strides for the page, K/V, and head dimensions.
Related: vllm-project/vllm#34158 — relaxes the same assertion and moves the contiguity check closer to the dequant kernel. Our change goes further by fixing the dequant kernel itself to work with non-contiguous inputs.
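A minimal pure-Python sketch (not the actual kernel code) of why shape-derived offsets break for a non-contiguous per-layer view: the layout, dimension names, and stride values below are illustrative assumptions, not taken from the vLLM source.

```python
# Hypothetical illustration: flat offset computed from the shape alone vs.
# offset computed from the view's actual strides.

def offset_from_shape(page, kv, head, num_kv, num_heads, head_size):
    # Contiguity assumption: elements are packed with no gaps.
    return ((page * num_kv + kv) * num_heads + head) * head_size

def offset_from_strides(page, kv, head, stride_page, stride_kv, stride_head):
    # Stride-based offset: correct for any layout, including per-layer
    # views whose page stride skips over other layers' data.
    return page * stride_page + kv * stride_kv + head * stride_head

# Assumed unified allocation: [pages, layers, kv, heads, head_size],
# viewing layer 0 of a 2-layer cache.
num_kv, num_heads, head_size, num_layers = 2, 8, 64, 2
stride_head = head_size                          # 64
stride_kv = num_heads * stride_head              # 512
stride_page = num_layers * num_kv * stride_kv    # 2048, skips the other layer

shape_off = offset_from_shape(1, 0, 0, num_kv, num_heads, head_size)
stride_off = offset_from_strides(1, 0, 0, stride_page, stride_kv, stride_head)
print(shape_off, stride_off)  # 1024 vs 2048: the shape-based offset reads the wrong page
```

For a contiguous tensor the two formulas agree; the divergence appears exactly when the view's strides differ from the packed ones, which is the case this PR fixes.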
Test
New standalone test: tests/kernels/attention/test_trtllm_kvfp8_dequant.py
71 tests total (64 parametrized + 7 corner cases), all passing:
num_kv_heads={1,8}, head_size={64,128}, block_size={16,32}, batch_size={1,4}, num_pages_per_seq={3,8}, contiguous={True,False}
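As a sanity check on the count above, the six two-valued parameters yield a 2^6 grid. A sketch of that arithmetic (the actual test file's parametrization structure may differ):

```python
import itertools

# Parameter grid from the PR description; each axis has two values.
params = {
    "num_kv_heads": [1, 8],
    "head_size": [64, 128],
    "block_size": [16, 32],
    "batch_size": [1, 4],
    "num_pages_per_seq": [3, 8],
    "contiguous": [True, False],
}

combos = list(itertools.product(*params.values()))
print(len(combos))  # 64 parametrized cases; plus 7 corner cases = 71 tests
```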