diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 47fb46496226..fd5328d8ca14 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -128,7 +128,17 @@ def indexer_k_quant_and_cache_triton(
     # In real layout, we store the first portion as kv cache value
     # and second portion as kv cache scale
     kv_cache = kv_cache.view(num_blocks, -1)
-    kv_cache_value = kv_cache[:, : block_size * head_dim]
+    # The K region must be viewed as the platform FP8 dtype so the kernel's
+    # `(val/scale).to(kv_cache_ptr.type.element_ty)` cast produces FP8 bit
+    # patterns. Without this view the slice keeps `kv_cache`'s uint8 dtype
+    # and the cast becomes integer truncation (e.g. 1.7 -> 1, -0.3 -> 255),
+    # writing arbitrary integers as if they were FP8 values. The downstream
+    # paged-MQA logits kernel reads those bytes as FP8 and produces garbage,
+    # which causes the DSv3.2 sparse-MLA + MTP top-K accuracy collapse for
+    # context_len > 2048.
+    kv_cache_value = kv_cache[:, : block_size * head_dim].view(
+        current_platform.fp8_dtype()
+    )
     kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
     head_tile_size = head_tile_size // kv_cache.element_size()
     grid = (num_tokens,)
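
For context, a small eager-PyTorch sketch of the failure mode the new comment describes (illustrative only, not vLLM or Triton code; `torch.float8_e4m3fn` stands in for whatever `current_platform.fp8_dtype()` returns on the target platform, which on ROCm may be a different FP8 variant): casting a scaled value into a uint8-typed destination truncates it to an integer, casting into an FP8-typed destination stores a real FP8 encoding, and reinterpreting the truncated integer bytes as FP8 yields garbage.

```python
import torch

# Minimal repro of the dtype-dependent cast behavior, in eager PyTorch.
# torch.float8_e4m3fn is only a stand-in for current_platform.fp8_dtype().
val = torch.tensor([1.7, 0.5, 0.0625])
scale = torch.tensor(1.0)

# Destination dtype uint8: the cast is integer truncation (1.7 -> 1),
# so the stored bytes are small integers, not FP8 encodings.
bad_bytes = (val / scale).to(torch.uint8)
print(bad_bytes)  # tensor([1, 0, 0], dtype=torch.uint8)

# Destination dtype FP8: the cast rounds to the nearest representable
# FP8 value and stores its bit pattern.
good_bytes = (val / scale).to(torch.float8_e4m3fn)
print(good_bytes.float())  # approximately tensor([1.7500, 0.5000, 0.0625])

# What the downstream kernel effectively did with the buggy cache:
# reinterpret the truncated integer bytes as FP8. Byte value 1 decodes
# to a tiny subnormal, nothing like the original 1.7.
print(bad_bytes.view(torch.float8_e4m3fn).float())
```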