162 changes: 162 additions & 0 deletions tests/v1/attention/test_kv_head_stride_canonicalization.py
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for canonicalize_singleton_dim_strides.

Background
----------
When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 1 KV head
per rank), PyTorch's is_contiguous() returns True for *any* stride on the
size-1 dimension. The KV cache allocator can therefore produce a tensor
where that singleton dim has stride = 1 element (2 bytes for bf16) instead
of the canonical product-of-remaining-dims value.

CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+)
requires every stride except the innermost to be a multiple of 16 bytes.
A 2-byte stride on a non-innermost dim triggers cudaErrorIllegalInstruction.

canonicalize_singleton_dim_strides() patches degenerate strides on all
size-1 dimensions via torch.as_strided — zero-copy.

The degenerate stride manifests at different positions in different backends:
- FlashInfer: stride(-3) after kv_cache.permute() → shape [..., 1, B, D]
- FlashAttention: stride(-2) after kv_cache.unbind(0) → shape [N, B, 1, D]
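
A minimal illustration (arbitrary demo shapes, not tied to either backend):

    >>> t = torch.zeros(2, 1, 4, dtype=torch.bfloat16)  # strides (4, 4, 1)
    >>> t_bad = t.as_strided(t.shape, (4, 1, 1))        # degenerate stride on dim 1
    >>> t_bad.is_contiguous()  # still reported contiguous
    True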
"""

import torch

from vllm.utils.torch_utils import canonicalize_singleton_dim_strides

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _inject_degenerate_stride(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Return a view of t with a degenerate stride (=1) on a size-1 dim."""
    assert t.shape[dim] == 1, f"dim {dim} must have size 1"
    strides = list(t.stride())
    strides[dim] = 1  # inject the bug
    return t.as_strided(t.shape, strides)


# ---------------------------------------------------------------------------
# Tests: canonicalize_singleton_dim_strides
# ---------------------------------------------------------------------------


class TestCanonicalizeSingletonDimStrides:
    def test_flashinfer_layout_dim_neg3(self):
        """FlashInfer path: degenerate stride at dim -3 (num_kv_heads)."""
        # Shape after permute: [num_blocks, 2, num_kv_heads, block_size, head_size]
        num_blocks, block_size, head_size = 64, 16, 128
        t = torch.zeros(num_blocks, 2, 1, block_size, head_size, dtype=torch.bfloat16)
        t_deg = _inject_degenerate_stride(t, dim=-3)

        assert t_deg.stride(-3) == 1  # confirm degenerate
        assert t_deg.is_contiguous()  # PyTorch doesn't notice

        fixed = canonicalize_singleton_dim_strides(t_deg)

        assert fixed.stride(-3) == block_size * head_size  # canonical = 2048
        assert fixed.stride(-2) == head_size  # inner dims unchanged
        assert fixed.stride(-1) == 1

    def test_flash_attn_layout_dim_neg2(self):
        """FlashAttention path: degenerate stride at dim -2 (num_kv_heads)."""
        # Shape after unbind(0): [num_blocks, block_size, num_kv_heads, head_size]
        num_blocks, block_size, head_size = 64, 16, 128
        t = torch.zeros(num_blocks, block_size, 1, head_size, dtype=torch.bfloat16)
        t_deg = _inject_degenerate_stride(t, dim=-2)

        assert t_deg.stride(-2) == 1
        assert t_deg.is_contiguous()

        fixed = canonicalize_singleton_dim_strides(t_deg)

        assert fixed.stride(-2) == head_size  # canonical = 128
        assert fixed.stride(-1) == 1

    def test_canonical_strides_returned_as_is(self):
        """No degenerate strides → same object returned (no copy, no new view)."""
        t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
        result = canonicalize_singleton_dim_strides(t)
        assert result is t

    def test_multi_kv_heads_unchanged(self):
        """num_kv_heads > 1 → strides are already canonical → unchanged."""
        t = torch.zeros(16, 2, 4, 16, 128, dtype=torch.bfloat16)
        original_strides = t.stride()
        result = canonicalize_singleton_dim_strides(t)
        assert result.stride() == original_strides

    def test_data_pointer_preserved(self):
        """Fix is zero-copy: same underlying storage."""
        t = torch.zeros(8, 2, 1, 16, 128, dtype=torch.bfloat16)
        t_deg = _inject_degenerate_stride(t, dim=-3)
        fixed = canonicalize_singleton_dim_strides(t_deg)
        assert fixed.data_ptr() == t_deg.data_ptr()
        assert fixed.storage_offset() == t_deg.storage_offset()

    def test_multiple_singleton_dims(self):
        """All size-1 dims with degenerate strides are fixed."""
        # Shape: [1, 1, 8, 32] — two size-1 dims
        t = torch.zeros(1, 1, 8, 32, dtype=torch.float16)
        # Both size-1 dims get degenerate strides
        t_deg = t.as_strided(t.shape, (1, 1, 32, 1))  # both leading dims = 1

        fixed = canonicalize_singleton_dim_strides(t_deg)

        assert fixed.stride(0) == 1 * 8 * 32  # canonical: 256
        assert fixed.stride(1) == 1 * 8 * 32  # canonical: 256 (same since size-1)
        assert fixed.stride(2) == 32
        assert fixed.stride(3) == 1

    def test_various_shapes_flashinfer(self):
        """Correctness across different block_size / head_size for FlashInfer layout."""
        for block_size, head_size in [(16, 64), (16, 128), (32, 128), (16, 256)]:
            t = torch.zeros(8, 2, 1, block_size, head_size, dtype=torch.bfloat16)
            t_deg = _inject_degenerate_stride(t, dim=-3)
            fixed = canonicalize_singleton_dim_strides(t_deg)
            assert fixed.stride(-3) == block_size * head_size, (
                f"Failed for block_size={block_size}, head_size={head_size}: "
                f"got stride(-3)={fixed.stride(-3)}"
            )

    def test_various_shapes_flash_attn(self):
        """Correctness across different shapes for FlashAttention layout."""
        for block_size, head_size in [(16, 64), (16, 128), (32, 128)]:
            t = torch.zeros(8, block_size, 1, head_size, dtype=torch.bfloat16)
            t_deg = _inject_degenerate_stride(t, dim=-2)
            fixed = canonicalize_singleton_dim_strides(t_deg)
            assert fixed.stride(-2) == head_size, (
                f"Failed for block_size={block_size}, head_size={head_size}: "
                f"got stride(-2)={fixed.stride(-2)}"
            )

    def test_tma_alignment_satisfied_after_fix_bf16(self):
        """After fix, all non-innermost strides meet 16-byte TMA alignment for bf16."""
        t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
        t_deg = _inject_degenerate_stride(t, dim=-3)
        fixed = canonicalize_singleton_dim_strides(t_deg)

        element_size = fixed.element_size()  # 2 bytes for bf16
        for i, s in enumerate(fixed.stride()):
            assert (s * element_size) % 16 == 0 or i == len(fixed.stride()) - 1, (
                f"dim {i} stride {s} * {element_size} bytes not 16-byte aligned"
            )

    def test_non_contiguous_outer_dims_preserved(self):
        """Outer (non-size-1) non-contiguous strides are left unchanged."""
        # Simulate cross-layer unified allocation: num_blocks stride is non-canonical
        # but the inner dims should be fixed.
        base = torch.zeros(200, 2, 1, 16, 128, dtype=torch.bfloat16)
        # Slice every 2nd block → non-canonical outer stride
        t_sliced = base[::2]  # shape [100, 2, 1, 16, 128], stride[0] = 2*canonical
        t_deg = _inject_degenerate_stride(t_sliced, dim=-3)

        fixed = canonicalize_singleton_dim_strides(t_deg)

        # Outer stride should be unchanged (not a size-1 dim)
        assert fixed.stride(0) == t_sliced.stride(0)
        # Inner degenerate stride should be fixed
        assert fixed.stride(-3) == 16 * 128
2 changes: 1 addition & 1 deletion vllm/utils/cpu_resource_utils.py
@@ -125,7 +125,7 @@ def get_allowed_cpu_list() -> list[LogicalCPUInfo]:
if platform.system() == "Darwin":
    return cpu_list

global_allowed_cpu_id_list = os.sched_getaffinity(0)
global_allowed_cpu_id_list = os.sched_getaffinity(0) # type: ignore[attr-defined]
logical_cpu_list = [x for x in cpu_list if x.id in global_allowed_cpu_id_list]

return logical_cpu_list
26 changes: 26 additions & 0 deletions vllm/utils/torch_utils.py
@@ -110,6 +110,32 @@ def is_strictly_contiguous(t: torch.Tensor) -> bool:
return True


def canonicalize_singleton_dim_strides(t: torch.Tensor) -> torch.Tensor:
    """Fix degenerate strides on size=1 dimensions for CUDA TMA compatibility.

    PyTorch allows any stride on a size=1 dim (is_contiguous() is always True
    there), so a size=1 dim may have stride=1 (2 bytes for bf16) instead of
    the canonical product(shape[i+1:]). CUDA TMA on H100+ requires every
    stride except the innermost to be 16-byte aligned; a 2-byte stride on a
    non-innermost dim triggers cudaErrorIllegalInstruction. Zero-copy: only
    stride metadata is patched, via as_strided; t is returned unchanged if
    all size=1 strides are already canonical.
    """
    if 1 not in t.shape:
        return t
    strides = list(t.stride())
    shape = t.shape
    prev_stride = 1
    changed = False
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == 1 and strides[i] != prev_stride:
            strides[i] = prev_stride
            changed = True
        prev_stride = strides[i] * shape[i]
    if not changed:
        return t
    return t.as_strided(t.shape, strides)
Comment on lines +125 to +136

Member
Can you smoke test with some non-trivial benchmark that this doesn't add overhead? With piece-wise CUDA graphs, this method is executed in eager mode.

Contributor Author
Good call. Here are the results:

Overhead benchmark on H100 (PyTorch 2.6.0, NVIDIA container 24.12):

num_kv_heads > 1 (common — early exit):       234 ns/call
num_kv_heads = 1, stride already canonical:  1060 ns/call
num_kv_heads = 1, degenerate stride (fix):   2921 ns/call

Test script:

import timeit
import torch

def canonicalize_singleton_dim_strides(t):
    if 1 not in t.shape:
        return t
    strides = list(t.stride())
    shape = t.shape
    prev_stride = 1
    changed = False
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] == 1 and strides[i] != prev_stride:
            strides[i] = prev_stride
            changed = True
        prev_stride = strides[i] * shape[i]
    if not changed:
        return t
    return t.as_strided(t.shape, strides)

N = 1_000_000

# Common path: num_kv_heads=8 (no size=1 dims) -> early exit at "1 not in t.shape"
t_common = torch.zeros(64, 2, 8, 16, 128, dtype=torch.bfloat16)
elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_common), number=N)
print(f"Common (no size=1):    {elapsed/N*1e9:.1f} ns/call")

# num_kv_heads=1, stride already canonical -> loop runs, no change
t_ok = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_ok), number=N)
print(f"size=1 canonical:      {elapsed/N*1e9:.1f} ns/call")

# num_kv_heads=1 with degenerate stride (the bug case) -> fix applied
t_deg = torch.as_strided(
    torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16),
    (64, 2, 1, 16, 128), (4096, 2048, 1, 128, 1)
)
elapsed = timeit.timeit(lambda: canonicalize_singleton_dim_strides(t_deg), number=N)
print(f"Degenerate stride fix: {elapsed/N*1e9:.1f} ns/call")



@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
24 changes: 23 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -11,7 +11,10 @@

from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.utils.torch_utils import (
    canonicalize_singleton_dim_strides,
    is_quantized_kv_cache,
)
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionImpl,
@@ -747,6 +750,23 @@ def forward(

# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
# FA3/4 on H100+ uses TMA, which requires 16-byte stride alignment.
# See vllm.utils.torch_utils.canonicalize_singleton_dim_strides.
fixed_k = canonicalize_singleton_dim_strides(key_cache)
fixed_v = canonicalize_singleton_dim_strides(value_cache)
if fixed_k is not key_cache or fixed_v is not value_cache:
    logger.debug(
        "Canonicalized degenerate KV cache strides (FlashAttention): "
        "shape=%s, key strides before=%s after=%s, "
        "value strides before=%s after=%s",
        key_cache.shape,
        key_cache.stride(),
        fixed_k.stride(),
        value_cache.stride(),
        fixed_v.stride(),
    )
key_cache, value_cache = fixed_k, fixed_v

if is_quantized_kv_cache(self.kv_cache_dtype):
    # queries are quantized in the attention layer
@@ -861,6 +881,8 @@ def do_kv_cache_update(
# we use direct Q, K, V tensors without caching
return

# Scatter write into the KV cache using slot_mapping indices.
# No TMA kernel is invoked here, so stride canonicalization is not needed.
key_cache, value_cache = kv_cache.unbind(0)

# Reshape the input keys and values and store them in the cache.
25 changes: 24 additions & 1 deletion vllm/v1/attention/backends/flash_attn_diffkv.py
@@ -4,7 +4,11 @@

import torch

from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.logger import init_logger
from vllm.utils.torch_utils import (
    canonicalize_singleton_dim_strides,
    is_quantized_kv_cache,
)
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import (
    get_flash_attn_version,
@@ -25,6 +29,8 @@
    cascade_attention,
)

logger = init_logger(__name__)


class FlashAttentionDiffKVBackend(FlashAttentionBackend):
    # Default to 128 for this backend
@@ -204,6 +210,23 @@ def forward(
# Different head_size for K and V
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
# Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP).
# FA3/4 on H100+ uses TMA, which requires 16-byte stride alignment.
# See vllm.utils.torch_utils.canonicalize_singleton_dim_strides.
fixed_k = canonicalize_singleton_dim_strides(key_cache)
fixed_v = canonicalize_singleton_dim_strides(value_cache)
if fixed_k is not key_cache or fixed_v is not value_cache:
    logger.debug(
        "Canonicalized degenerate KV cache strides (FlashAttentionDiffKV): "
        "shape=%s, key strides before=%s after=%s, "
        "value strides before=%s after=%s",
        key_cache.shape,
        key_cache.stride(),
        fixed_k.stride(),
        value_cache.stride(),
        fixed_v.stride(),
    )
key_cache, value_cache = fixed_k, fixed_v

if is_quantized_kv_cache(self.kv_cache_dtype):
    # queries are quantized in the attention layer
50 changes: 31 additions & 19 deletions vllm/v1/attention/backends/flashinfer.py
@@ -43,6 +43,7 @@
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import (
    canonicalize_singleton_dim_strides,
    is_quantized_kv_cache,
    is_strictly_contiguous,
    nvfp4_kv_cache_full_dim,
@@ -1479,6 +1480,21 @@ def forward(

stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order) # HND and contiguous
# Fix degenerate strides on any size-1 dimension (e.g. num_kv_heads=1
# with TP=8). PyTorch permits non-canonical strides on size-1 dims;
# CUDA TMA requires 16-byte alignment on all non-innermost strides.
# canonicalize_singleton_dim_strides patches metadata via as_strided —
# zero-copy. See vllm.utils.torch_utils.
fixed = canonicalize_singleton_dim_strides(kv_cache_permute)
if fixed is not kv_cache_permute:
    logger.debug(
        "Canonicalized degenerate KV cache strides (FlashInfer): "
        "shape=%s, strides before=%s, strides after=%s",
        kv_cache_permute.shape,
        kv_cache_permute.stride(),
        fixed.stride(),
    )
kv_cache_permute = fixed

# For NVFP4, the kv_cache last dim is full_dim (data + scale packed).
# Split into correctly-strided data and scale views.
@@ -1568,10 +1584,11 @@ def forward(
else:
assert isinstance(attn_metadata.prefill, TRTLLMPrefill)
# prefill_query may be non-contiguous or have degenerate strides
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
prefill_query = prefill_query.contiguous().reshape(prefill_query.shape)
# on size=1 dims. contiguous() ensures memory layout; then
# canonicalize_singleton_dim_strides fixes any remaining
# degenerate strides on size=1 dims for TMA alignment.
prefill_query = prefill_query.contiguous()
prefill_query = canonicalize_singleton_dim_strides(prefill_query)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.prefill.block_tables
seq_lens_prefill = attn_metadata.prefill.seq_lens
@@ -1621,11 +1638,9 @@
# with fp8 kv cache, we can construct a mock block
# and mock kv cache with BF16 KV involved in the prefill
#
# The inner (block_size, head_size) dims must be
# contiguous; outer dims may have non-canonical strides
# (e.g. cross-layer unified allocation).
# Degenerate strides on outer dims break TMA descriptors
# (see flashinfer-ai/flashinfer#2232).
kv_cache_permute = canonicalize_singleton_dim_strides(
    kv_cache_permute
)
kv_strides = kv_cache_permute.stride()
assert (
    kv_strides[-1] == 1
@@ -1732,12 +1747,13 @@ def forward(
if needs_fp8_out:
output[:num_decode_tokens].copy_(out_decode.to(output.dtype))
else:
# decode_query may be non-contiguous or have degenerate strides
assert isinstance(attn_metadata.decode, TRTLLMDecode)
# First ensure memory contiguity, then fix degenerate strides
# with reshape. contiguous() alone doesn't fix degenerate
# strides when a dimension has size 1.
decode_query = decode_query.contiguous().reshape(decode_query.shape)
# decode_query may be non-contiguous or have degenerate strides
# on size=1 dims. contiguous() ensures memory layout; then
# canonicalize_singleton_dim_strides fixes any remaining
# degenerate strides on size=1 dims for TMA alignment.
decode_query = decode_query.contiguous()
decode_query = canonicalize_singleton_dim_strides(decode_query)
workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_decode = attn_metadata.decode.block_tables
seq_lens_decode = attn_metadata.decode.seq_lens
@@ -1748,11 +1764,7 @@
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode)
# kv_cache outer dims may be non-contiguous (e.g.
# cross-layer unified allocation), but inner dims
# (block_size, head_size) must be contiguous and
# strides must be canonical to avoid TMA descriptor
# failures (see flashinfer-ai/flashinfer#2232).
kv_cache_permute = canonicalize_singleton_dim_strides(kv_cache_permute)
kv_strides = kv_cache_permute.stride()
assert (
    kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]