Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,17 @@ def indexer_k_quant_and_cache_triton(
# In the real layout, the first portion of each block stores the kv cache
# values and the second portion stores the kv cache scales.
kv_cache = kv_cache.view(num_blocks, -1)
kv_cache_value = kv_cache[:, : block_size * head_dim]
# The K region must be viewed as the platform FP8 dtype so the kernel's
# `(val/scale).to(kv_cache_ptr.type.element_ty)` cast produces FP8 bit
# patterns. Without this view the slice keeps `kv_cache`'s uint8 dtype
# and the cast becomes integer truncation (e.g. 1.7 -> 1, -0.3 -> 255),
# writing arbitrary integers as if they were FP8 values. The downstream
# paged-MQA logits kernel reads those bytes as FP8 and produces garbage,
# which causes the DSv3.2 sparse-MLA + MTP top-K accuracy collapse for
# context_len > 2048.
kv_cache_value = kv_cache[:, : block_size * head_dim].view(
current_platform.fp8_dtype()
)
kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
head_tile_size = head_tile_size // kv_cache.element_size()
grid = (num_tokens,)
Expand Down