-
-
Notifications
You must be signed in to change notification settings - Fork 16.7k
[Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalInstruction #40737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
robertgshaw2-redhat
merged 10 commits into
vllm-project:main
from
the-david-oy:fix/tma-kv-head-stride-degenerate
May 3, 2026
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
a25ecac
[Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalI…
5e7536b
Accurately set strides for non-contiguous tensors
the-david-oy d1863eb
Address review: clarify size=1 wording, revert incidental assert refo…
the-david-oy a19353e
Address review: replace reshape with canonicalize_singleton_dim_strid…
the-david-oy 3d3bbb5
Address review: replace kv stride assertions with canonicalize_single…
the-david-oy 1b9b963
Address review: split contiguous + canonicalize into sequential state…
the-david-oy 122283b
Merge branch 'main' into fix/tma-kv-head-stride-degenerate
the-david-oy 2a52c1d
Fix missing logger in flash_attn_diffkv and mypy sched_getaffinity
the-david-oy 9ad8883
Address review: shorten docstring, restore inner-dim assertions, add …
the-david-oy dbc2ca7
Merge branch 'main' into fix/tma-kv-head-stride-degenerate
vadiklyutiy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
162 changes: 162 additions & 0 deletions
162
tests/v1/attention/test_kv_head_stride_canonicalization.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Unit tests for canonicalize_singleton_dim_strides. | ||
|
|
||
| Background | ||
| ---------- | ||
| When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 1 KV head | ||
| per rank), PyTorch's is_contiguous() returns True for *any* stride on the | ||
| size-1 dimension. The KV cache allocator can therefore produce a tensor | ||
| where that singleton dim has stride = 1 element (2 bytes for bf16) instead | ||
| of the canonical product-of-remaining-dims value. | ||
|
|
||
| CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+) | ||
| requires all non-outermost strides to be multiples of 16 bytes. A 2-byte | ||
| stride triggers cudaErrorIllegalInstruction. | ||
|
|
||
| canonicalize_singleton_dim_strides() patches degenerate strides on all | ||
| size-1 dimensions via torch.as_strided — zero-copy. | ||
|
|
||
| The degenerate stride manifests at different positions in different backends: | ||
| - FlashInfer: stride(-3) after kv_cache.permute() → shape [..., 1, B, D] | ||
| - FlashAttention: stride(-2) after kv_cache.unbind(0) → shape [N, B, 1, D] | ||
| """ | ||
|
|
||
| import torch | ||
|
|
||
| from vllm.utils.torch_utils import canonicalize_singleton_dim_strides | ||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Helpers | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _inject_degenerate_stride(t: torch.Tensor, dim: int) -> torch.Tensor: | ||
| """Return a view of t with a degenerate (stride=1) on a size-1 dim.""" | ||
| assert t.shape[dim] == 1, f"dim {dim} must have size 1" | ||
| strides = list(t.stride()) | ||
| strides[dim] = 1 # inject the bug | ||
| return t.as_strided(t.shape, strides) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Tests: canonicalize_singleton_dim_strides | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| class TestCanonicalizeSingletonDimStrides: | ||
| def test_flashinfer_layout_dim_neg3(self): | ||
| """FlashInfer path: degenerate stride at dim -3 (num_kv_heads).""" | ||
| # Shape after permute: [num_blocks, 2, num_kv_heads, block_size, head_size] | ||
| num_blocks, block_size, head_size = 64, 16, 128 | ||
| t = torch.zeros(num_blocks, 2, 1, block_size, head_size, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-3) | ||
|
|
||
| assert t_deg.stride(-3) == 1 # confirm degenerate | ||
| assert t_deg.is_contiguous() # PyTorch doesn't notice | ||
|
|
||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
|
|
||
| assert fixed.stride(-3) == block_size * head_size # canonical = 2048 | ||
| assert fixed.stride(-2) == head_size # inner dims unchanged | ||
| assert fixed.stride(-1) == 1 | ||
|
|
||
| def test_flash_attn_layout_dim_neg2(self): | ||
| """FlashAttention path: degenerate stride at dim -2 (num_kv_heads).""" | ||
| # Shape after unbind(0): [num_blocks, block_size, num_kv_heads, head_size] | ||
| num_blocks, block_size, head_size = 64, 16, 128 | ||
| t = torch.zeros(num_blocks, block_size, 1, head_size, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-2) | ||
|
|
||
| assert t_deg.stride(-2) == 1 | ||
| assert t_deg.is_contiguous() | ||
|
|
||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
|
|
||
| assert fixed.stride(-2) == head_size # canonical = 128 | ||
| assert fixed.stride(-1) == 1 | ||
|
|
||
| def test_canonical_strides_returned_as_is(self): | ||
| """No degenerate strides → same object returned (no copy, no new view).""" | ||
| t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16) | ||
| result = canonicalize_singleton_dim_strides(t) | ||
| assert result is t | ||
|
|
||
| def test_multi_kv_heads_unchanged(self): | ||
| """num_kv_heads > 1 → strides are already canonical → unchanged.""" | ||
| t = torch.zeros(16, 2, 4, 16, 128, dtype=torch.bfloat16) | ||
| original_strides = t.stride() | ||
| result = canonicalize_singleton_dim_strides(t) | ||
| assert result.stride() == original_strides | ||
|
|
||
| def test_data_pointer_preserved(self): | ||
| """Fix is zero-copy: same underlying storage.""" | ||
| t = torch.zeros(8, 2, 1, 16, 128, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-3) | ||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
| assert fixed.data_ptr() == t_deg.data_ptr() | ||
| assert fixed.storage_offset() == t_deg.storage_offset() | ||
|
|
||
| def test_multiple_singleton_dims(self): | ||
| """All size-1 dims with degenerate strides are fixed.""" | ||
| # Shape: [1, 1, 8, 32] — two size-1 dims | ||
| t = torch.zeros(1, 1, 8, 32, dtype=torch.float16) | ||
| # Both size-1 dims get degenerate strides | ||
| t_deg = t.as_strided(t.shape, (1, 1, 32, 1)) # both leading dims = 1 | ||
|
|
||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
|
|
||
| assert fixed.stride(0) == 1 * 8 * 32 # canonical: 256 | ||
| assert fixed.stride(1) == 1 * 8 * 32 # canonical: 256 (same since size-1) | ||
| assert fixed.stride(2) == 32 | ||
| assert fixed.stride(3) == 1 | ||
|
|
||
| def test_various_shapes_flashinfer(self): | ||
| """Correctness across different block_size / head_size for FlashInfer layout.""" | ||
| for block_size, head_size in [(16, 64), (16, 128), (32, 128), (16, 256)]: | ||
| t = torch.zeros(8, 2, 1, block_size, head_size, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-3) | ||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
| assert fixed.stride(-3) == block_size * head_size, ( | ||
| f"Failed for block_size={block_size}, head_size={head_size}: " | ||
| f"got stride(-3)={fixed.stride(-3)}" | ||
| ) | ||
|
|
||
| def test_various_shapes_flash_attn(self): | ||
| """Correctness across different shapes for FlashAttention layout.""" | ||
| for block_size, head_size in [(16, 64), (16, 128), (32, 128)]: | ||
| t = torch.zeros(8, block_size, 1, head_size, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-2) | ||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
| assert fixed.stride(-2) == head_size, ( | ||
| f"Failed for block_size={block_size}, head_size={head_size}: " | ||
| f"got stride(-2)={fixed.stride(-2)}" | ||
| ) | ||
|
|
||
| def test_tma_alignment_satisfied_after_fix_bf16(self): | ||
| """After fix, all strides meet 16-byte TMA alignment for bf16.""" | ||
| t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16) | ||
| t_deg = _inject_degenerate_stride(t, dim=-3) | ||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
|
|
||
| element_size = fixed.element_size() # 2 bytes for bf16 | ||
| for i, s in enumerate(fixed.stride()): | ||
| assert (s * element_size) % 16 == 0 or i == len(fixed.stride()) - 1, ( | ||
| f"dim {i} stride {s} * {element_size} bytes not 16-byte aligned" | ||
| ) | ||
|
|
||
| def test_non_contiguous_outer_dims_preserved(self): | ||
| """Outer (non-size-1) non-contiguous strides are left unchanged.""" | ||
| # Simulate cross-layer unified allocation: num_blocks stride is non-canonical | ||
| # but the inner dims should be fixed. | ||
| base = torch.zeros(200, 2, 1, 16, 128, dtype=torch.bfloat16) | ||
| # Slice every 2nd block → non-canonical outer stride | ||
| t_sliced = base[::2] # shape [100, 2, 1, 16, 128], stride[0] = 2*canonical | ||
| t_deg = _inject_degenerate_stride(t_sliced, dim=-3) | ||
|
|
||
| fixed = canonicalize_singleton_dim_strides(t_deg) | ||
|
|
||
| # Outer stride should be unchanged (not a size-1 dim) | ||
| assert fixed.stride(0) == t_sliced.stride(0) | ||
| # Inner degenerate stride should be fixed | ||
| assert fixed.stride(-3) == 16 * 128 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you smoke test with some non-trivial benchmark that this doesn't include overhead? with piece-wise CUDA graphs, this method is executed in eager-mode
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call. Here are the results:
Test script: