Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def __init__(
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)

if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
Expand Down Expand Up @@ -1436,7 +1437,6 @@ def forward(
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(prefill_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_prefill)
assert is_strictly_contiguous(seq_lens_prefill)
Expand All @@ -1461,6 +1461,20 @@ def forward(
# and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block
# and mock kv cache with BF16 KV involved in the prefill
#
# The inner (block_size, head_size) dims must be
# contiguous; outer dims may have non-canonical strides
# (e.g. cross-layer unified allocation).
# Degenerate strides on outer dims break TMA descriptors
# (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1
and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute,
block_tables_prefill,
Expand Down Expand Up @@ -1549,10 +1563,21 @@ def forward(
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(decode_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode)
# kv_cache outer dims may be non-contiguous (e.g.
# cross-layer unified allocation), but inner dims
# (block_size, head_size) must be contiguous and
# strides must be canonical to avoid TMA descriptor
# failures (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)

if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None
Expand Down
Loading