Skip to content

[Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalInstruction#40737

Merged
robertgshaw2-redhat merged 10 commits intovllm-project:mainfrom
the-david-oy:fix/tma-kv-head-stride-degenerate
May 3, 2026
Merged

[Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalInstruction#40737
robertgshaw2-redhat merged 10 commits intovllm-project:mainfrom
the-david-oy:fix/tma-kv-head-stride-degenerate

Conversation

@the-david-oy
Copy link
Copy Markdown
Contributor

@the-david-oy the-david-oy commented Apr 23, 2026

Purpose

Fix cudaErrorIllegalInstruction → NCCL allgather hang → EngineDeadError that occurs when num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with --tensor-parallel-size 8) and prefix caching is enabled.

Root cause: When prefix-cached KV blocks are freed and reallocated, PyTorch can produce tensor views with a degenerate stride on the singleton num_kv_heads dimension. is_contiguous() returns True for any stride on a size-1 dimension, so .contiguous() does not fix it. CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+) requires all non-outermost strides to be multiples of 16 bytes; a bf16 stride=1 (2 bytes) violates this and faults.

This bug affects all three attention backends:

  • FlashInfer (flashinfer.py): degenerate stride at dim -3 of kv_cache_permute after permute(), shape [..., num_kv_heads=1, block_size, head_size]
  • FlashAttention (flash_attn.py): degenerate stride at dim -2 of key_cache after unbind(0), shape [num_blocks, block_size, num_kv_heads=1, head_size]
  • FlashAttentionDiffKV (flash_attn_diffkv.py): same dim -2 issue after the packed [..., head_size_k + head_size_v] last-dim slice

Fix: Add canonicalize_singleton_dim_strides() to vllm/utils/torch_utils.py. Uses torch.as_strided to patch size-1 dim strides to their canonical C-contiguous values — zero-copy, only stride metadata is updated. Safe because a size-1 dimension is never stepped across in pointer arithmetic (0 × stride = 0 regardless of stride value). Applied in all three backends immediately after the KV cache tensor is sliced for reading.

Also tightens the existing TRTLLM stride assertions in flashinfer.py to include stride(-3) so any future regression on SM100 surfaces with a readable error rather than an opaque CUDA fault.

Test Plan

Unit tests (no GPU required — pure PyTorch metadata operations):

PYTHONPATH=. pytest tests/v1/attention/test_kv_head_stride_canonicalization.py -v

Integration reproducer (requires H100, TP=8, model with 8 total KV heads):

  vllm serve <model> \
    --tensor-parallel-size 8 \
    --enable-prefix-caching \
    --attention-backend flashinfer  # or flash_attn
  ## Send two requests with a shared prefix. Second request sometimes triggers
  ## block reallocation → degenerate stride → crash without this fix.

Workaround confirmation: Setting VLLM_ATTENTION_BACKEND=FLASH_ATTN on older vLLM versions still crashes (both backends are affected). The fix resolves both. Disabling prefix caching is the other workaround for getting rid of this crash.

Test Result

Unit tests: 10/10 passed

PASSED test_flashinfer_layout_dim_neg3
PASSED test_flash_attn_layout_dim_neg2
PASSED test_canonical_strides_returned_as_is
PASSED test_multi_kv_heads_unchanged
PASSED test_data_pointer_preserved
PASSED test_multiple_singleton_dims
PASSED test_various_shapes_flashinfer
PASSED test_various_shapes_flash_attn
PASSED test_tma_alignment_satisfied_after_fix_bf16
PASSED test_non_contiguous_outer_dims_preserved

Copy link
Copy Markdown

@claude claude Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@github-actions
Copy link
Copy Markdown

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

Agent Guidelines

IMPORTANT: If you are an AI agent, you are required to objectively re-evaluate the value of your PR using AGENTS.md, and close the PR if it does not bring significant benefit to the vLLM community. Failure to do so may result in an immediate ban.

🚀

@the-david-oy the-david-oy changed the title [Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalI… [Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalInstruction Apr 23, 2026
@mergify mergify Bot added nvidia v1 bug Something isn't working labels Apr 23, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a utility function canonicalize_singleton_dim_strides to ensure CUDA TMA compatibility by patching non-canonical strides on size-1 dimensions. This fix addresses potential cudaErrorIllegalInstruction errors on H100+ GPUs when using FlashAttention 3/4 or FlashInfer with single KV head configurations (e.g., high TP degrees). The utility is integrated into the FlashAttention and FlashInfer backends, and a new test suite is added to verify the fix across different layouts and shapes. I have no feedback to provide.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

Could you pls provide exact reproducer with server and client?
Also could you print before fail degenerate strides?

@the-david-oy the-david-oy force-pushed the fix/tma-kv-head-stride-degenerate branch from c9e2479 to d6e538f Compare April 24, 2026 18:24
…nstruction

When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 8 total KV
heads, 1 per rank) and prefix caching is enabled, paged KV cache blocks
that are freed and reallocated can produce PyTorch tensor views with a
degenerate stride on the singleton head dimension.

PyTorch's is_contiguous() permits any stride value on size-1 dimensions.
The allocator therefore emits stride = 1 element (2 bytes for bf16) rather
than the canonical product-of-remaining-dims value.

CUDA TMA (Tensor Memory Accelerator), used by FlashInfer's XQA SM90 decode
kernel and by Flash-Attention 3/4 on H100+, requires all non-outermost
strides to be multiples of 16 bytes. A 2-byte stride violates this and
produces cudaErrorIllegalInstruction on the affected TP rank. That rank then
hangs in the TP allgather; after 600s the watchdog fires and the engine
shuts down with EngineDeadError.

This bug affects all three attention backends:
- FlashInfer (flashinfer.py): degenerate stride at dim -3 of kv_cache_permute
  after permute(), shape [..., num_kv_heads=1, block_size, head_size]
- FlashAttention (flash_attn.py): degenerate stride at dim -2 of key_cache
  after unbind(0), shape [num_blocks, block_size, num_kv_heads=1, head_size]
- FlashAttentionDiffKV (flash_attn_diffkv.py): same dim -2 issue after the
  packed [..., head_size_k + head_size_v] last-dim slice

Fix: add canonicalize_singleton_dim_strides() to vllm/utils/torch_utils.py.
This uses torch.as_strided to patch size-1 dim strides to their canonical
C-contiguous values — zero-copy, only stride metadata is updated. A size-1
dimension is never stepped across in pointer arithmetic (0 * stride = 0
regardless of stride), so patching to canonical does not change any memory
access and is unconditionally safe. Apply in all three backends immediately
after the KV cache tensor is sliced.

Also tighten the existing TRTLLM stride assertions in flashinfer.py to
include stride(-3) so any future regression on SM100 surfaces with a
readable error rather than an opaque CUDA fault.

Reproducer: model with 8 total KV heads, TP=8, prefix caching enabled,
FlashInfer or FlashAttention backend (H100). Second request causes prefix-
cached blocks to be freed and reallocated → degenerate stride → crash.

Ref: flashinfer-ai/flashinfer#2232

Co-authored-by: Claude <claude@anthropic.com>
Signed-off-by: David Oy <david@baseten.co>
@the-david-oy the-david-oy force-pushed the fix/tma-kv-head-stride-degenerate branch from d6e538f to a25ecac Compare April 24, 2026 18:25
@the-david-oy
Copy link
Copy Markdown
Contributor Author

the-david-oy commented Apr 24, 2026

Thanks for reviewing so quickly, Vadim! I added the debug statements.

We ran into this with production traffic, so I'm working on getting and testing a reproducer with a server and client.

@the-david-oy
Copy link
Copy Markdown
Contributor Author

the-david-oy commented Apr 25, 2026

@vadiklyutiy Thank for your patience! Please see the results of my investigation below. TLDR is that given this error happens intermittently in production without the fix, I wasn't able to create a reliable server/client repro. I provided an injection repro, plus explained why it's necessary and safe to add as a guard.

Reproduction and Verification

Background

We observed production crashes on Qwen3.5-397B with --tensor-parallel-size 8 and --enable-prefix-caching on 8× H100. The symptom is cudaErrorIllegalInstruction on one or more TP ranks after a request that reuses prefix-cached KV blocks, followed by a NCCL hang and eventual EngineDeadError. Disabling --enable-prefix-caching eliminates the crash. This is the same scenario reported in flashinfer-ai/flashinfer#2232 (which you actually commented on 😄).


Root cause

With --tensor-parallel-size 8 and 8 total KV heads, each rank holds 1 KV head. After kv_cache.unbind(0) (FA3 backend) or kv_cache.permute(*stride_order) (FlashInfer backend), the resulting tensor has shape (..., 1, head_size) — a size-1 dimension for num_kv_heads.

PyTorch's contiguity rules allow any stride value for a size-1 dimension, because index × stride = 0 × stride = 0 regardless — the stride is never used in pointer arithmetic. Consequently, tensor.is_contiguous() returns True even when the size-1 dim carries stride=1 (1 element = 2 bytes for bf16) rather than the canonical head_size value. .contiguous() therefore returns self unchanged and cannot fix this.

CUDA's TMA (Tensor Memory Accelerator), used by Flash-Attention 3/4 and FlashInfer on H100/H200, validates all non-outermost, non-innermost strides during cuTensorMapEncodeTiled. The minimum aligned stride is 16 bytes. A stride of 2 bytes (1 bf16 element) triggers:

Error: Failed to initialize the TMA descriptor
globalStrides  (2, 128, 2, 2048, 0)
                ^
                └─ 1 element × 2 bytes = 2 bytes < 16-byte TMA minimum

On kernel paths that fall back gracefully (FA3 on H100) the server continues running but silently uses a slower code path. On paths without a fallback, this becomes cudaErrorIllegalInstruction.


Why we cannot reproduce without injection

I traced all standard KV cache allocation paths in vLLM 0.19.1 and none naturally produces stride=1 for the num_kv_heads dimension:

Path num_kv_heads stride TMA-safe?
torch.zeros(2, N, block_size, 1, head_size) head_size elements
flat buffer → .view(...) head_size elements
_update_hybrid_attention_mamba_layout (as_strided_) preserves head_size
cross-layer permute + slice head_size × some factor

The production crash is real and --enable-prefix-caching is the reliable trigger, but I have not isolated the exact allocator state that yields stride=1 reliably. It may be specific to the production vLLM/PyTorch/CUDA build, a code path I did not trace, or a subtle race in block eviction that reuses storage in a way that introduces the degenerate stride.


Injection-based demonstration (H100, vLLM 0.19.1)

To show the mechanism concretely, I patched flash_attn.py to inject stride=1 for any size-1 dimension immediately after kv_cache.unbind(0) — precisely the stride value PyTorch is permitted to produce — and ran a standard server/client.

Setup: TinyLlama-1.1B-Chat (4 KV heads) at --tensor-parallel-size 4 gives 1 KV head per rank, the same ratio as Qwen3.5-397B at TP=8.

Step 1 — patch flash_attn.py to inject the degenerate stride.

Apply the following after line 697 of vllm/v1/attention/backends/flash_attn.py
(right after key_cache, value_cache = kv_cache.unbind(0)):

# ---- CRASH INJECTION (remove before merging) --------------------------------
import sys as _sys
_modified = False
for _c in (key_cache, value_cache):
    _sh, _st = _c.shape, list(_c.stride())
    for _i, _sz in enumerate(_sh[:-1]):   # skip innermost (element stride)
        if _sz == 1 and _i > 0:           # skip outermost dim too
            _st[_i] = 1                   # degenerate: 1 elem = 2 bytes < 16-byte TMA min
            _modified = True
    if _modified:
        _c.set_(_c.storage(), storage_offset=_c.storage_offset(),
                size=_sh, stride=tuple(_st))
if _modified:
    print(f"[CRASH_DEMO] degenerate strides: key={key_cache.stride()} "
          f"shape={key_cache.shape}", file=_sys.stderr, flush=True)
# ---- END INJECTION ----------------------------------------------------------

Step 2 — start the server (any 4× H100 node):

vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --tensor-parallel-size 4 \
    --enable-prefix-caching \
    --max-model-len 2048 \
    --enforce-eager

Step 3 — run the client:

from openai import OpenAI
client = OpenAI(base_url="http://localhost:8000/v1", api_key="x")
PREFIX = "The quick brown fox jumps over the lazy dog. " * 50  # ~450 tokens
for i in range(3):
    r = client.chat.completions.create(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        messages=[{"role": "user", "content": PREFIX + "\n\nWhat animal jumps?"}],
        max_tokens=32, temperature=0,
    )
    print(r.choices[0].message.content)

Observed server stderr (all 4 TP ranks, interleaved):

[CRASH_DEMO] Injected degenerate strides: key=(1024, 64, 1, 1) shape=torch.Size([357864, 16, 1, 64])
Error: Failed to initialize the TMA descriptor 1
TMA Desc Addr:   0x7ffea364f600
globalStrides  (2,128,2,2048,0)
Error: Failed to initialize the TMA descriptor 1
TMA Desc Addr:   0x7ffed12d9500
globalStrides  (2,128,2,2048,0)
...  (repeated for all 4 ranks)

The 2 in globalStrides is the degenerate stride in bytes: 1 element × 2 bytes bf16 = 2 bytes. FA3 has a fallback so requests complete on H100; on kernel paths without a fallback this terminates the worker.

Server with this PR's fix applied — same client, zero TMA errors:

[FIX_DEMO] Before fix: key=(1024, 64,  1, 1)   ← degenerate: 1 elem = 2 bytes
[FIX_DEMO] After fix:  key=(1024, 64, 64, 1)   ← canonical:  64 elem = 128 bytes ✓

Why the fix is correct and safe

canonicalize_singleton_dim_strides (in vllm/utils/torch_utils.py) uses torch.as_strided — a metadata-only operation — to set the stride of any size-1 dimension to its canonical product(shape[i+1:]) value:

def canonicalize_singleton_dim_strides(t: torch.Tensor) -> torch.Tensor:
    strides = list(t.stride())
    shape = t.shape
    s = 1
    changed = False
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == 1 and strides[i] != s:
            strides[i] = s
            changed = True
        s *= shape[i]
    if not changed:
        return t          # ← zero overhead on the common path (num_kv_heads > 1)
    return t.as_strided(t.shape, strides)

Safety: For a size-1 dimension, every valid index is 0, so index × stride = 0 always. The stride is multiplied by zero and never contributes to any memory address. Changing it to the canonical value is provably a no-op with respect to memory access.

Performance: When num_kv_heads > 1 (the common case), no stride is degenerate and the function returns self immediately — no allocation, no copy, no overhead.

Correctness of 3-way fix: The same guard is applied in all three affected backends (flash_attn.py:756, flashinfer.py:1454, flash_attn_diffkv.py:199) so the fix is consistent regardless of which backend is selected.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

This is the same scenario reported in flashinfer-ai/flashinfer#2232 (which you actually commented on 😄).

Yeah, I do remember that. Also #32417 and #32008 are related

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

This is the same scenario reported in flashinfer-ai/flashinfer#2232 (which you actually commented on 😄).

Yeah, I do remember that. Also #32417 and #32008 are related

And this is the reason I pay attention to the PR. Something that was fixed, not fully fixed.

Comment thread vllm/utils/torch_utils.py Outdated
Comment thread vllm/v1/attention/backends/flash_attn.py Outdated
Comment thread vllm/utils/torch_utils.py Outdated
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com>
@the-david-oy the-david-oy force-pushed the fix/tma-kv-head-stride-degenerate branch from 89916a9 to f4142ec Compare April 28, 2026 15:37
…rmat

Signed-off-by: David Oy <david.oy@baseten.co>
@the-david-oy the-david-oy force-pushed the fix/tma-kv-head-stride-degenerate branch from f4142ec to d1863eb Compare April 28, 2026 15:38
Copy link
Copy Markdown
Collaborator

@vadiklyutiy vadiklyutiy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also change 2 places with reshape(t.shape) from #32417 to use canonicalize_singleton_dim_strides()?

Comment thread vllm/v1/attention/backends/flashinfer.py Outdated
Comment thread vllm/v1/attention/backends/flashinfer.py Outdated
…es for query tensors

prefill_query and decode_query in the TRTLLMGen paths used
.contiguous().reshape(t.shape) to fix degenerate strides on size=1
dimensions. Replace with the purpose-built helper for clarity and
consistency with the kv_cache fix applied elsewhere in this PR.

Signed-off-by: David Oy <david.oy@baseten.co>
…ton_dim_strides

The manual kv_strides / expected_head_stride assertion blocks in the
TRTLLMGen prefill and decode paths are replaced with explicit calls to
canonicalize_singleton_dim_strides, per reviewer feedback. This removes
the hard assert (which produces an opaque Python exception on stride
mismatches) in favour of the fix-and-continue approach already used
earlier in the same dispatch path.

Signed-off-by: David Oy <david.oy@baseten.co>
@the-david-oy
Copy link
Copy Markdown
Contributor Author

@pavanimajety I saw you reviewed the two related PRs Vadim linked above. Would you be able to take a look at this one as well?

@the-david-oy
Copy link
Copy Markdown
Contributor Author

@vadiklyutiy Thanks for reviewing above! Apologies, I did not realize you were the code owner and a maintainer.

@vadiklyutiy
Copy link
Copy Markdown
Collaborator

@vadiklyutiy Thanks for reviewing above! Apologies, I did not realize you were the code owner and a maintainer.

how does it impact? :)

@vadiklyutiy vadiklyutiy added the verified Run pre-commit for new contributors without triggering other tests label Apr 30, 2026
Comment thread vllm/v1/attention/backends/flashinfer.py Outdated
Comment thread vllm/v1/attention/backends/flashinfer.py Outdated
@github-project-automation github-project-automation Bot moved this to In review in NVIDIA Apr 30, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 30, 2026

Hi @dyastremsky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@the-david-oy
Copy link
Copy Markdown
Contributor Author

@vadiklyutiy Thanks for reviewing above! Apologies, I did not realize you were the code owner and a maintainer.

how does it impact? :)

That's why I tagged Pavani. I thought we needed the code owner's s review here, but we had it. 🙏

…ments

Reviewer preference: write contiguous() and canonicalize_singleton_dim_strides
as separate assignments so the execution order (contiguous first) is visually
unambiguous.

Signed-off-by: David Oy <david.oy@baseten.co>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 30, 2026

Hi @dyastremsky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 30, 2026

Hi @dyastremsky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

flash_attn_diffkv.py used logger.debug() without defining logger;
add init_logger(__name__) after all imports.

cpu_resource_utils.py: os.sched_getaffinity is Linux-only; mypy stubs
exclude it on non-Linux targets.  Add type: ignore[attr-defined] — the
Darwin early-return above already guards the runtime call.

Signed-off-by: David Oy <david.oy@baseten.co>
@mergify mergify Bot added the cpu Related to CPU backends label Apr 30, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Apr 30, 2026

Hi @dyastremsky, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label May 1, 2026
Comment thread vllm/utils/torch_utils.py Outdated
Comment on lines +114 to +140
"""Canonicalize strides on dimensions of size=1 for CUDA TMA compatibility.

PyTorch's ``is_contiguous()`` returns ``True`` for *any* stride value on a
dimension of size=1, because a dimension of size=1 is never actually stepped
across. As a result, memory allocators may produce tensors where a size=1
dimension has ``stride = 1`` (one element) rather than the canonical
``product(shape[i+1:])``.

CUDA's TMA (Tensor Memory Accelerator), used by FlashInfer's XQA SM90
decode kernel and by Flash-Attention 3/4 on H100+, requires every
non-outermost stride to be a multiple of 16 bytes. For a bf16 tensor,
``stride = 1`` element means 2 bytes — well below the 16-byte minimum —
and triggers ``cudaErrorIllegalInstruction``.

This function uses ``torch.as_strided`` to patch size=1 dim strides to
their canonical value. **No data is copied**; only stride
metadata is updated. It is safe because a dimension of size=1 is *never
stepped across* in pointer arithmetic — ``index * stride`` is always
``0 * stride = 0`` regardless of the stride value. Patching to the
canonical value therefore does not change any memory access; it only
satisfies TMA's alignment check.

Typical trigger: paged KV cache with ``num_kv_heads_per_rank == 1`` (e.g.
Qwen3.5-397B with ``--tensor-parallel-size 8``). When prefix-cached
blocks are freed and reallocated, the resulting view can have a degenerate
stride on the singleton head dimension.
"""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a huge docstring that seems like slop, can this be cleaned up?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! Let me know if the updated comment is concise enough.

Comment on lines -1560 to +1578
# The inner (block_size, head_size) dims must be
# contiguous; outer dims may have non-canonical strides
# (e.g. cross-layer unified allocation).
# Degenerate strides on outer dims break TMA descriptors
# (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1
and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
kv_cache_permute = canonicalize_singleton_dim_strides(
kv_cache_permute
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new function only patches size-1 dim strides, so I don't see why we can remove the assertion since we aren't catching when the inner dims are not contiguous

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Restored the inner-dim assertion in the latest commit. Thanks for calling this out, I agree that canonicalize only patches size=1 dims and shouldn't remove that guard.

Comment thread vllm/utils/torch_utils.py
Comment on lines +141 to +152
strides = list(t.stride())
shape = t.shape
prev_stride = 1
changed = False
for i in range(len(shape) - 1, -1, -1):
if shape[i] == 1 and strides[i] != prev_stride:
strides[i] = prev_stride
changed = True
prev_stride = strides[i] * shape[i]
if not changed:
return t
return t.as_strided(t.shape, strides)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you smoke test with some non-trivial benchmark that this doesn't include overhead? with piece-wise CUDA graphs, this method is executed in eager-mode

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Here are the results:

Overhead benchmark on H100 (PyTorch 2.6.0, NVIDIA container 24.12):
num_kv_heads > 1  (common — early exit)   234 ns/call
num_kv_heads = 1, stride already canonical  1060 ns/call
num_kv_heads = 1, degenerate stride (fix)  2921 ns/call

Test script:

import timeit
  import torch

  def canonicalize_singleton_dim_strides(t):
      if 1 not in t.shape:
          return t
      strides = list(t.stride())
      shape = t.shape
      prev_stride = 1
      changed = False
      for i in range(len(shape) - 1, -1, -1):
          if shape[i] == 1 and strides[i] != prev_stride:
              strides[i] = prev_stride
              changed = True
          prev_stride = strides[i] * shape[i]
      if not changed:
          return t
      return t.as_strided(t.shape, strides)

  N = 1_000_000

  # Common path: num_kv_heads=8 (no size=1 dims) -> early exit at "1 not in t.shape"
  t_common = torch.zeros(64, 2, 8, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_common), number=N)
  import timeit
  import torch
  import timeit
  import torch

  def canonicalize_singleton_dim_strides(t):
      if 1 not in t.shape:
          return t
      strides = list(t.stride())
      shape = t.shape
      prev_stride = 1
      changed = False
      for i in range(len(shape) - 1, -1, -1):
          if shape[i] == 1 and strides[i] != prev_stride:
              strides[i] = prev_stride
              changed = True
          prev_stride = strides[i] * shape[i]
      if not changed:
          return t
      return t.as_strided(t.shape, strides)

  N = 1_000_000

  # Common path: num_kv_heads=8 (no size=1 dims) -> early exit at "1 not in t.shape"
  t_common = torch.zeros(64, 2, 8, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_common), number=N)
  print(f"Common (no size=1):    {elapsed/N*1e9:.1f} ns/call")

  # num_kv_heads=1, stride already canonical -> loop runs, no change
  t_ok = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_ok), number=N)
  print(f"size=1 canonical:      {elapsed/N*1e9:.1f} ns/call")

  # num_kv_heads=1 with degenerate stride (the bug case) -> fix applied
  t_deg = torch.as_strided(
      torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16),
      (64, 2, 1, 16, 128), (4096, 2048, 1, 128, 1)
  )
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_deg), number=N)
  print(f"Degenerate stride fix: {elapsed/N*1e9:.1f} ns/call")

…fast path

- canonicalize_singleton_dim_strides docstring cut to 3 sentences
- Restore kv_strides[-1]/[-2] assertion after canonicalize in both
  TRTLLMGen paths: canonicalize only fixes size=1 dims; the assertion
  guards block_size and head_size which are never size=1 but should
  always be canonical (regression detection)
- Add early exit "if 1 not in t.shape: return t" to skip list allocation
  and loop for the common case (num_kv_heads > 1); torch.Size.__contains__
  is C-level

Signed-off-by: David Oy <david.oy@baseten.co>
@github-project-automation github-project-automation Bot moved this from In review to Ready in NVIDIA May 3, 2026
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) May 3, 2026 22:16
@robertgshaw2-redhat robertgshaw2-redhat merged commit 66dfee7 into vllm-project:main May 3, 2026
63 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 3, 2026
joa-stdn pushed a commit to joa-stdn/vllm that referenced this pull request May 4, 2026
…nstruction (vllm-project#40737)

Signed-off-by: David Oy <david@baseten.co>
Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com>
Signed-off-by: David Oy <david.oy@baseten.co>
Co-authored-by: David Oy <david@baseten.co>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Signed-off-by: Joachim Studnia <joachim@mistral.ai>
chaojun-zhang pushed a commit to chaojun-zhang/vllm that referenced this pull request May 6, 2026
…nstruction (vllm-project#40737)

Signed-off-by: David Oy <david@baseten.co>
Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com>
Signed-off-by: David Oy <david.oy@baseten.co>
Co-authored-by: David Oy <david@baseten.co>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…nstruction (vllm-project#40737)

Signed-off-by: David Oy <david@baseten.co>
Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com>
Signed-off-by: David Oy <david.oy@baseten.co>
Co-authored-by: David Oy <david@baseten.co>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request May 7, 2026
…nstruction (vllm-project#40737)

Signed-off-by: David Oy <david@baseten.co>
Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com>
Signed-off-by: David Oy <david.oy@baseten.co>
Co-authored-by: David Oy <david@baseten.co>
Co-authored-by: Claude <claude@anthropic.com>
Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working cpu Related to CPU backends nvidia ready ONLY add when PR is ready to merge/full CI is needed v1 verified Run pre-commit for new contributors without triggering other tests

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants