diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 11c64d8cbc82..f0c7e9366f11 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -64,6 +64,36 @@ T = TypeVar("T") +def is_strictly_contiguous(t: torch.Tensor) -> bool: + """ + Check if tensor is contiguous AND has no degenerate strides. + + A degenerate stride occurs when a dimension has size 1 but the stride + doesn't match the canonical contiguous layout. This can cause issues + in some CUDA kernels that rely on stride values for memory access. + + For a C-contiguous tensor of shape (d0, d1, ..., dn), the expected + strides are: stride[i] = product(shape[i+1:]) for all i, with stride[-1]=1. + + Example with torch.Size([16, 1, 8, 32]): + - Canonical strides: (256, 256, 32, 1) + - Degenerate strides: (256, 1, 32, 1) # dim=1 has size=1, so it can + # carry a non-canonical stride in dim=1 while is_contiguous() is True + """ + if not t.is_contiguous(): + return False + + # Check that strides match canonical contiguous layout + shape = t.shape + strides = t.stride() + expected_stride = 1 + for i in range(len(shape) - 1, -1, -1): + if strides[i] != expected_stride: + return False + expected_stride *= shape[i] + return True + + @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8dc2838d88a5..314a8f2bb42d 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -40,6 +40,7 @@ ) from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import is_strictly_contiguous from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, @@ -1392,11 +1393,11 @@ def forward( # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" - assert prefill_query.is_contiguous() - assert 
kv_cache_permute.is_contiguous() - assert workspace_buffer.is_contiguous() - assert block_tables_prefill.is_contiguous() - assert seq_lens_prefill.is_contiguous() + assert is_strictly_contiguous(prefill_query) + assert is_strictly_contiguous(kv_cache_permute) + assert is_strictly_contiguous(workspace_buffer) + assert is_strictly_contiguous(block_tables_prefill) + assert is_strictly_contiguous(seq_lens_prefill) if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None @@ -1503,11 +1504,11 @@ def forward( # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" - assert decode_query.is_contiguous() - assert kv_cache_permute.is_contiguous() - assert workspace_buffer.is_contiguous() - assert block_tables_decode.is_contiguous() - assert seq_lens_decode.is_contiguous() + assert is_strictly_contiguous(decode_query) + assert is_strictly_contiguous(kv_cache_permute) + assert is_strictly_contiguous(workspace_buffer) + assert is_strictly_contiguous(block_tables_decode) + assert is_strictly_contiguous(seq_lens_decode) if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None