Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions tests/v1/attention/test_kv_head_stride_canonicalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for canonicalize_singleton_dim_strides.

Background
----------
When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 1 KV head
per rank), PyTorch's is_contiguous() returns True for *any* stride on the
size-1 dimension. The KV cache allocator can therefore produce a tensor
where that singleton dim has stride = 1 element (2 bytes for bf16) instead
of the canonical product-of-remaining-dims value.

CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+)
requires all non-outermost strides to be multiples of 16 bytes. A 2-byte
stride triggers cudaErrorIllegalInstruction.

canonicalize_singleton_dim_strides() patches degenerate strides on all
size-1 dimensions via torch.as_strided — zero-copy.

The degenerate stride manifests at different positions in different backends:
- FlashInfer: stride(-3) after kv_cache.permute() → shape [..., 1, B, D]
- FlashAttention: stride(-2) after kv_cache.unbind(0) → shape [N, B, 1, D]
"""

import torch

from vllm.utils.torch_utils import canonicalize_singleton_dim_strides

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _inject_degenerate_stride(t: torch.Tensor, dim: int) -> torch.Tensor:
"""Return a view of t with a degenerate (stride=1) on a size-1 dim."""
assert t.shape[dim] == 1, f"dim {dim} must have size 1"
strides = list(t.stride())
strides[dim] = 1 # inject the bug
return t.as_strided(t.shape, strides)


# ---------------------------------------------------------------------------
# Tests: canonicalize_singleton_dim_strides
# ---------------------------------------------------------------------------


class TestCanonicalizeSingletonDimStrides:
def test_flashinfer_layout_dim_neg3(self):
"""FlashInfer path: degenerate stride at dim -3 (num_kv_heads)."""
# Shape after permute: [num_blocks, 2, num_kv_heads, block_size, head_size]
num_blocks, block_size, head_size = 64, 16, 128
t = torch.zeros(num_blocks, 2, 1, block_size, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)

assert t_deg.stride(-3) == 1 # confirm degenerate
assert t_deg.is_contiguous() # PyTorch doesn't notice

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(-3) == block_size * head_size # canonical = 2048
assert fixed.stride(-2) == head_size # inner dims unchanged
assert fixed.stride(-1) == 1

def test_flash_attn_layout_dim_neg2(self):
"""FlashAttention path: degenerate stride at dim -2 (num_kv_heads)."""
# Shape after unbind(0): [num_blocks, block_size, num_kv_heads, head_size]
num_blocks, block_size, head_size = 64, 16, 128
t = torch.zeros(num_blocks, block_size, 1, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-2)

assert t_deg.stride(-2) == 1
assert t_deg.is_contiguous()

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(-2) == head_size # canonical = 128
assert fixed.stride(-1) == 1

def test_canonical_strides_returned_as_is(self):
"""No degenerate strides → same object returned (no copy, no new view)."""
t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
result = canonicalize_singleton_dim_strides(t)
assert result is t

def test_multi_kv_heads_unchanged(self):
"""num_kv_heads > 1 → strides are already canonical → unchanged."""
t = torch.zeros(16, 2, 4, 16, 128, dtype=torch.bfloat16)
original_strides = t.stride()
result = canonicalize_singleton_dim_strides(t)
assert result.stride() == original_strides

def test_data_pointer_preserved(self):
"""Fix is zero-copy: same underlying storage."""
t = torch.zeros(8, 2, 1, 16, 128, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.data_ptr() == t_deg.data_ptr()
assert fixed.storage_offset() == t_deg.storage_offset()

def test_multiple_singleton_dims(self):
"""All size-1 dims with degenerate strides are fixed."""
# Shape: [1, 1, 8, 32] — two size-1 dims
t = torch.zeros(1, 1, 8, 32, dtype=torch.float16)
# Both size-1 dims get degenerate strides
t_deg = t.as_strided(t.shape, (1, 1, 32, 1)) # both leading dims = 1

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(0) == 1 * 8 * 32 # canonical: 256
assert fixed.stride(1) == 1 * 8 * 32 # canonical: 256 (same since size-1)
assert fixed.stride(2) == 32
assert fixed.stride(3) == 1

def test_various_shapes_flashinfer(self):
"""Correctness across different block_size / head_size for FlashInfer layout."""
for block_size, head_size in [(16, 64), (16, 128), (32, 128), (16, 256)]:
t = torch.zeros(8, 2, 1, block_size, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.stride(-3) == block_size * head_size, (
f"Failed for block_size={block_size}, head_size={head_size}: "
f"got stride(-3)={fixed.stride(-3)}"
)

def test_various_shapes_flash_attn(self):
"""Correctness across different shapes for FlashAttention layout."""
for block_size, head_size in [(16, 64), (16, 128), (32, 128)]:
t = torch.zeros(8, block_size, 1, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-2)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.stride(-2) == head_size, (
f"Failed for block_size={block_size}, head_size={head_size}: "
f"got stride(-2)={fixed.stride(-2)}"
)

def test_tma_alignment_satisfied_after_fix_bf16(self):
"""After fix, all strides meet 16-byte TMA alignment for bf16."""
t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)

element_size = fixed.element_size() # 2 bytes for bf16
for i, s in enumerate(fixed.stride()):
assert (s * element_size) % 16 == 0 or i == len(fixed.stride()) - 1, (
f"dim {i} stride {s} * {element_size} bytes not 16-byte aligned"
)

def test_non_contiguous_outer_dims_preserved(self):
"""Outer (non-size-1) non-contiguous strides are left unchanged."""
# Simulate cross-layer unified allocation: num_blocks stride is non-canonical
# but the inner dims should be fixed.
base = torch.zeros(200, 2, 1, 16, 128, dtype=torch.bfloat16)
# Slice every 2nd block → non-canonical outer stride
t_sliced = base[::2] # shape [100, 2, 1, 16, 128], stride[0] = 2*canonical
t_deg = _inject_degenerate_stride(t_sliced, dim=-3)

fixed = canonicalize_singleton_dim_strides(t_deg)

# Outer stride should be unchanged (not a size-1 dim)
assert fixed.stride(0) == t_sliced.stride(0)
# Inner degenerate stride should be fixed
assert fixed.stride(-3) == 16 * 128
42 changes: 42 additions & 0 deletions vllm/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,48 @@ def is_strictly_contiguous(t: torch.Tensor) -> bool:
return True


def canonicalize_singleton_dim_strides(t: torch.Tensor) -> torch.Tensor:
"""Canonicalize strides on size-1 dimensions for CUDA TMA compatibility.
Comment thread
vadiklyutiy marked this conversation as resolved.
Outdated

PyTorch's ``is_contiguous()`` returns ``True`` for *any* stride value on a
size-1 dimension, because a dimension of size 1 is never actually stepped
across. As a result, memory allocators may produce tensors where a size-1
dimension has ``stride = 1`` (one element) rather than the canonical
``product(shape[i+1:])``.

CUDA's TMA (Tensor Memory Accelerator), used by FlashInfer's XQA SM90
decode kernel and by Flash-Attention 3/4 on H100+, requires every
non-outermost stride to be a multiple of 16 bytes. For a bf16 tensor,
``stride = 1`` element means 2 bytes — well below the 16-byte minimum —
and triggers ``cudaErrorIllegalInstruction``.

This function uses ``torch.as_strided`` to patch size-1 dim strides to
their canonical C-contiguous value. **No data is copied**; only stride
metadata is updated. It is safe because a size-1 dimension is *never
stepped across* in pointer arithmetic — ``index * stride`` is always
``0 * stride = 0`` regardless of the stride value. Patching to the
canonical value therefore does not change any memory access; it only
satisfies TMA's alignment check.

Typical trigger: paged KV cache with ``num_kv_heads_per_rank == 1`` (e.g.
Qwen3.5-397B with ``--tensor-parallel-size 8``). When prefix-cached
blocks are freed and reallocated, the resulting view can have a degenerate
stride on the singleton head dimension.
"""
strides = list(t.stride())
shape = t.shape
s = 1
changed = False
for i in range(len(shape) - 1, -1, -1):
if shape[i] == 1 and strides[i] != s:
strides[i] = s
changed = True
s *= shape[i]
Comment thread
the-david-oy marked this conversation as resolved.
Outdated
if not changed:
return t
return t.as_strided(t.shape, strides)
Comment on lines +125 to +136

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you smoke test with some non-trivial benchmark that this doesn't include overhead? with piece-wise CUDA graphs, this method is executed in eager-mode

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Here are the results:

Overhead benchmark on H100 (PyTorch 2.6.0, NVIDIA container 24.12):
num_kv_heads > 1  (common — early exit)   234 ns/call
num_kv_heads = 1, stride already canonical  1060 ns/call
num_kv_heads = 1, degenerate stride (fix)  2921 ns/call

Test script:

import timeit
  import torch

  def canonicalize_singleton_dim_strides(t):
      if 1 not in t.shape:
          return t
      strides = list(t.stride())
      shape = t.shape
      prev_stride = 1
      changed = False
      for i in range(len(shape) - 1, -1, -1):
          if shape[i] == 1 and strides[i] != prev_stride:
              strides[i] = prev_stride
              changed = True
          prev_stride = strides[i] * shape[i]
      if not changed:
          return t
      return t.as_strided(t.shape, strides)

  N = 1_000_000

  # Common path: num_kv_heads=8 (no size=1 dims) -> early exit at "1 not in t.shape"
  t_common = torch.zeros(64, 2, 8, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_common), number=N)
  import timeit
  import torch
  import timeit
  import torch

  def canonicalize_singleton_dim_strides(t):
      if 1 not in t.shape:
          return t
      strides = list(t.stride())
      shape = t.shape
      prev_stride = 1
      changed = False
      for i in range(len(shape) - 1, -1, -1):
          if shape[i] == 1 and strides[i] != prev_stride:
              strides[i] = prev_stride
              changed = True
          prev_stride = strides[i] * shape[i]
      if not changed:
          return t
      return t.as_strided(t.shape, strides)

  N = 1_000_000

  # Common path: num_kv_heads=8 (no size=1 dims) -> early exit at "1 not in t.shape"
  t_common = torch.zeros(64, 2, 8, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_common), number=N)
  print(f"Common (no size=1):    {elapsed/N*1e9:.1f} ns/call")

  # num_kv_heads=1, stride already canonical -> loop runs, no change
  t_ok = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_ok), number=N)
  print(f"size=1 canonical:      {elapsed/N*1e9:.1f} ns/call")

  # num_kv_heads=1 with degenerate stride (the bug case) -> fix applied
  t_deg = torch.as_strided(
      torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16),
      (64, 2, 1, 16, 128), (4096, 2048, 1, 128, 1)
  )
  elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_deg), number=N)
  print(f"Degenerate stride fix: {elapsed/N*1e9:.1f} ns/call")



@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
Expand Down
31 changes: 27 additions & 4 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.utils.torch_utils import (
canonicalize_singleton_dim_strides,
is_quantized_kv_cache,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
Expand Down Expand Up @@ -747,6 +750,23 @@ def forward(

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
# FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment.
# See vllm.utils.torch_utils.canonicalize_singleton_dim_strides.
fixed_k = canonicalize_singleton_dim_strides(key_cache)
fixed_v = canonicalize_singleton_dim_strides(value_cache)
if fixed_k is not key_cache or fixed_v is not value_cache:
logger.debug(
"Canonicalized degenerate KV cache strides (FlashAttention): "
"shape=%s, key strides before=%s after=%s, "
"value strides before=%s after=%s",
key_cache.shape,
key_cache.stride(),
fixed_k.stride(),
value_cache.stride(),
fixed_v.stride(),
)
key_cache, value_cache = fixed_k, fixed_v

if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer
Expand Down Expand Up @@ -861,6 +881,8 @@ def do_kv_cache_update(
# we use direct Q, K, V tensors without caching
return

# Scatter write into the KV cache using slot_mapping indices.
# No TMA kernel is invoked here, so stride canonicalization is not needed.
key_cache, value_cache = kv_cache.unbind(0)

# Reshape the input keys and values and store them in the cache.
Expand Down Expand Up @@ -1156,9 +1178,10 @@ def cascade_attention(
) -> torch.Tensor:
assert alibi_slopes is None, "Cascade attention does not support ALiBi."
# TODO: Support sliding window.
assert sliding_window == (-1, -1), (
"Cascade attention does not support sliding window."
)
assert sliding_window == (
-1,
-1,
), "Cascade attention does not support sliding window."
Comment thread
vadiklyutiy marked this conversation as resolved.
Outdated

num_tokens = query.shape[0]
block_size = key_cache.shape[-3]
Expand Down
22 changes: 21 additions & 1 deletion vllm/v1/attention/backends/flash_attn_diffkv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@

import torch

from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.utils.torch_utils import (
canonicalize_singleton_dim_strides,
is_quantized_kv_cache,
)
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
Expand Down Expand Up @@ -190,6 +193,23 @@ def forward(
# Different head_size for K and V
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
# FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment.
# See vllm.utils.torch_utils.canonicalize_singleton_dim_strides.
fixed_k = canonicalize_singleton_dim_strides(key_cache)
fixed_v = canonicalize_singleton_dim_strides(value_cache)
if fixed_k is not key_cache or fixed_v is not value_cache:
logger.debug(
"Canonicalized degenerate KV cache strides (FlashAttentionDiffKV): "
"shape=%s, key strides before=%s after=%s, "
"value strides before=%s after=%s",
key_cache.shape,
key_cache.stride(),
fixed_k.stride(),
value_cache.stride(),
fixed_v.stride(),
)
key_cache, value_cache = fixed_k, fixed_v

if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer
Expand Down
Loading
Loading