From 831952464a205f36f7de48c646e3e070262ade7c Mon Sep 17 00:00:00 2001 From: Marko Hartikainen Date: Wed, 29 Apr 2026 20:29:09 +0300 Subject: [PATCH] [ROCm][DSv3.2] Fix FP8 cast in indexer_k_quant_and_cache_triton The Triton kernel `_indexer_k_quant_and_cache_kernel` casts the quantized value to the storage dtype via: fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty) The wrapper `indexer_k_quant_and_cache_triton` previously sliced the uint8-typed `kv_cache` to obtain the K region without re-viewing the slice as FP8: kv_cache_value = kv_cache[:, : block_size * head_dim] # uint8 dtype so the kernel's `kv_cache_ptr.type.element_ty` resolved to `tl.uint8` and the cast became integer truncation (e.g. 1.7 -> 1, -0.3 -> 255 wrap-around) rather than FP8 e4m3fnuz quantization. The kernel wrote arbitrary integer bytes that the downstream `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` kernel then read as FP8, producing garbage logits. For context_len <= 2048 the kernel's SplitKV scheduling and the relu+sum of weighted dot products coincidentally preserved top-K rank, but for context_len > 2048 top-K parity collapsed (~0% match vs reference) and DSv3.2 + MTP draft acceptance dropped sharply on MI355X. The commit message of ab4b563 (which introduced the SHUFFLE indexer) already noted a 3pp GSM8K regression "under investigation"; this is that regression's root cause. Fix: view the K-region slice as the platform FP8 dtype before passing it to the JIT kernel. The scale region is unchanged (already viewed as torch.float32 on the next line). After the fix the wrapper produces byte-identical output to the proven-correct path (`aiter.ops.cache.indexer_k_quant_and_cache` + `aiter.ops.shuffle.shuffle_weight`) which is what AITER's own benchmark `bench_deepgemm_attention.py` uses. Verification on MI355X (heads=64, hd=128, block=64, Preshuffle=True, ChunkK=256, ue8m0 scales) - kernel logits vs torch reference: ctx BUGGY FIXED max_rel(FIXED) 128 topk=1.000 -> topk=1.000 1.0e-3 512 topk=1.000 -> topk=1.000 8.8e-3 1024 topk=1.000 -> topk=1.000 7.7e-3 2048 topk=1.000 -> topk=1.000 3.4e-2 2049 topk=0.814 -> topk=1.000 1.3e-2 4096 topk=0.002 -> topk=1.000 3.3e-2 8192 topk=0.002 -> topk=1.000 1.5e-2 16384 topk=0.005 -> topk=1.000 3.6e-1 Verified end-to-end through the production wrapper `indexer_k_quant_and_cache_triton` mounted over the original source - all measured contexts up to 16384 reach top-K=1.000 vs the bf16-derived torch reference. Cross-verified with a 3-of-3 LLM swarm (gpt-5.1, Claude-Haiku-4.5, GPT-oss-20B) - unanimous TRUE / HIGH confidence on the bug diagnosis and the fix. Signed-off-by: Marko Hartikainen --- vllm/v1/attention/ops/rocm_aiter_mla_sparse.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 47fb46496226..fd5328d8ca14 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -128,7 +128,17 @@ def indexer_k_quant_and_cache_triton( # In real layout, we store the first portion as kv cache value # and second portion as kv cache scale kv_cache = kv_cache.view(num_blocks, -1) - kv_cache_value = kv_cache[:, : block_size * head_dim] + # The K region must be viewed as the platform FP8 dtype so the kernel's + # `(val/scale).to(kv_cache_ptr.type.element_ty)` cast produces FP8 bit + # patterns. Without this view the slice keeps `kv_cache`'s uint8 dtype + # and the cast becomes integer truncation (e.g. 1.7 -> 1, -0.3 -> 255), + # writing arbitrary integers as if they were FP8 values. The downstream + # paged-MQA logits kernel reads those bytes as FP8 and produces garbage, + # which causes the DSv3.2 sparse-MLA + MTP top-K accuracy collapse for + # context_len > 2048. + kv_cache_value = kv_cache[:, : block_size * head_dim].view( + current_platform.fp8_dtype() + ) kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32) head_tile_size = head_tile_size // kv_cache.element_size() grid = (num_tokens,)