
[ROCm][DSv3.2] Fix FP8 cast in indexer_k_quant_and_cache_triton (top-K accuracy regression for ctx>2048) #1

Closed

maeehart wants to merge 1 commit into frida-andersson:rocm/gluon-preshuffle-indexer from maeehart:mahartik/fix-indexer-fp8-cast

Conversation

@maeehart

Summary

Fixes the DeepSeek-V3.2 sparse-MLA + MTP top-K accuracy collapse for context_len > 2048 on AMD MI355X by correcting an FP8-cast bug in indexer_k_quant_and_cache_triton.

Root cause

The Triton kernel _indexer_k_quant_and_cache_kernel writes quantized values via:

fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)

The Python wrapper indexer_k_quant_and_cache_triton previously sliced the uint8-typed kv_cache to extract the K region without re-viewing it as FP8:

kv_cache_value = kv_cache[:, : block_size * head_dim]   # uint8 dtype

so the kernel saw kv_cache_ptr.type.element_ty == tl.uint8. Triton's <float32>.to(tl.uint8) is integer truncation (e.g. 1.7 -> 1, -0.3 -> 255 wrap-around) — not FP8 quantization. The kernel wrote arbitrary integer bytes into the cache, and the downstream _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle kernel then read those bytes as FP8 e4m3fnuz, producing garbled K values.
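
A minimal PyTorch sketch of the difference (illustrative only, assuming a PyTorch build that exposes torch.float8_e4m3fnuz; the actual cast happens inside the Triton kernel):

import torch

vals = torch.tensor([1.7, -0.3, 0.02], dtype=torch.float32)

# Buggy path, in effect: a float -> uint8 cast truncates to an integer
# (negatives wrap or clamp depending on backend), so the stored bytes are
# plain integers rather than FP8 bit patterns.
buggy_bytes = vals.to(torch.uint8)

# Fixed path, in effect: quantize to FP8 first, then store the FP8 bit patterns.
fixed_bytes = vals.to(torch.float8_e4m3fnuz).view(torch.uint8)

# A consumer that reinterprets the cache bytes as FP8 only recovers sensible
# values from the fixed path; the buggy bytes decode to unrelated FP8 values.
print(buggy_bytes.view(torch.float8_e4m3fnuz).to(torch.float32))
print(fixed_bytes.view(torch.float8_e4m3fnuz).to(torch.float32))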

For context_len <= 2048 the noise happened to preserve top-K rank often enough that the regression was hard to spot, but beyond 2048 the dominant logits no longer aligned with the reference and top-K parity collapsed (~0% match), which is exactly the MTP draft-acceptance cliff observed in production.

The commit message of ab4b563 (which introduced the SHUFFLE indexer) already noted a 3pp GSM8K regression "under investigation" — this is that regression's root cause.

The fix

-    kv_cache_value = kv_cache[:, : block_size * head_dim]
+    kv_cache_value = kv_cache[:, : block_size * head_dim].view(
+        current_platform.fp8_dtype()
+    )

current_platform.fp8_dtype() returns torch.float8_e4m3fnuz on AMD ROCm and torch.float8_e4m3fn on NVIDIA CUDA, matching the kernel's IS_FNUZ constant. The scale region is unchanged — it was already viewed as torch.float32 on the next line.
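
For illustration, a hedged sketch of the wrapper-side layout after the fix (function and parameter names other than kv_cache_value are illustrative, not the exact vLLM source):

import torch

def platform_fp8_dtype() -> torch.dtype:
    # Approximation of what current_platform.fp8_dtype() is described to
    # return above: the e4m3 "fnuz" variant on ROCm builds, e4m3fn otherwise.
    return torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn

def split_indexer_cache(kv_cache: torch.Tensor, block_size: int, head_dim: int):
    # K region: re-view the raw uint8 bytes as FP8 so the Triton kernel's
    # kv_cache_ptr.type.element_ty resolves to an FP8 element type.
    kv_cache_value = kv_cache[:, : block_size * head_dim].view(platform_fp8_dtype())
    # Scale region: unchanged by the fix, already reinterpreted as float32.
    kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
    return kv_cache_value, kv_cache_scale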

Validation

Byte-equivalence vs proven-correct path

Same input K (bf16), block=64, head_dim=128:

Block   BUGGY (uint8) vs C++ + shuffle_weight   FIXED (fp8 view) vs C++ + shuffle_weight
    0   8138 / 8192 bytes differ (99.3%)        0 / 8192 bytes differ
    1   8146 / 8192 bytes differ (99.4%)        0 / 8192 bytes differ
    2   8141 / 8192 bytes differ (99.4%)        0 / 8192 bytes differ

The FIXED path produces byte-identical output to aiter.ops.cache.indexer_k_quant_and_cache followed by aiter.ops.shuffle.shuffle_weight — the exact path used by AITER's official benchmark bench_deepgemm_attention.py.
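
The check behind the table amounts to counting mismatching bytes between the two producers' outputs for each block; a minimal sketch (names are illustrative, not the test script's actual API):

import torch

def bytes_differing(block_a: torch.Tensor, block_b: torch.Tensor) -> tuple[int, int]:
    # Reinterpret both producer outputs as raw bytes and count mismatches.
    a = block_a.contiguous().view(torch.uint8).flatten()
    b = block_b.contiguous().view(torch.uint8).flatten()
    assert a.numel() == b.numel()
    return int((a != b).sum().item()), a.numel()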

End-to-end logits parity (MI355X, single GPU)

Setup: heads=64, hd=128, block=64, Preshuffle=True, KVBlockSize=64, ChunkK=256, ue8m0 scales, single batch, single next_n. Reference: bf16-derived torch ground truth with the same per-128-group ue8m0 quantization the kernel uses.

   ctx   BUGGY topk@2048   FIXED topk@2048   FIXED max_rel
   128        1.000             1.000           1.0e-3
   512        1.000             1.000           8.8e-3
  1024        1.000             1.000           7.7e-3
  2048        1.000             1.000           3.4e-2
  2049        0.814             1.000           1.3e-2
  4096        0.002             1.000           3.3e-2
  8192        0.002             1.000           1.5e-2
 16384        0.005             1.000           3.6e-1

The 2049 boundary is the SplitKV/multi-block transition where the BUGGY path's accidental top-K agreement breaks down — exactly matching the user-observed MTP regression threshold.
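
For context, the per-128-group ue8m0 reference quantization mentioned in the setup can be sketched as power-of-two scales derived from each group's amax. A sketch only: the group size of 128 comes from the setup above, while the ceil rounding and the fp8_max of 240 for e4m3fnuz are assumptions.

import torch

def ue8m0_group_quant(k: torch.Tensor, group: int = 128, fp8_max: float = 240.0):
    # Split the last dim into groups and derive one power-of-two (ue8m0)
    # scale per group from that group's absolute maximum.
    x = k.float().unflatten(-1, (-1, group))
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = torch.exp2(torch.ceil(torch.log2(amax / fp8_max)))
    q = (x / scale).clamp(-fp8_max, fp8_max)
    return q.flatten(-2, -1), scale.squeeze(-1)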

Production-wrapper validation

Same harness, but mounting the patched rocm_aiter_mla_sparse.py over the original via docker -v and calling indexer_k_quant_and_cache_triton(...) (not the kernel directly):

ctx=  128  PATCHED_WRAPPER: bad/128=     0  topk=1.0000  max_rel=1.030e-03
ctx=  512  PATCHED_WRAPPER: bad/512=     0  topk=1.0000  max_rel=8.782e-03
ctx= 1024  PATCHED_WRAPPER: bad/1024=     0  topk=1.0000  max_rel=7.737e-03
ctx= 2048  PATCHED_WRAPPER: bad/2048=     0  topk=1.0000  max_rel=3.392e-02
ctx= 2049  PATCHED_WRAPPER: bad/2049=     0  topk=1.0000  max_rel=1.272e-02
ctx= 4096  PATCHED_WRAPPER: bad/4096=     0  topk=1.0000  max_rel=3.305e-02
ctx= 8192  PATCHED_WRAPPER: bad/8192=     0  topk=1.0000  max_rel=1.465e-02
ctx=16384  PATCHED_WRAPPER: bad/16384=     1  topk=1.0000  max_rel=3.641e-01
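
The topk and max_rel columns in these runs correspond to metrics along the following lines (a minimal sketch for 1-D logits; the k of 2048, the tie handling, and the denominator floor are assumptions, not the harness's exact code):

import torch

def topk_parity(logits: torch.Tensor, ref: torch.Tensor, k: int = 2048) -> float:
    # Fraction of the reference top-k indices also present in the kernel's top-k.
    k = min(k, ref.numel())
    got = set(logits.topk(k).indices.tolist())
    want = set(ref.topk(k).indices.tolist())
    return len(got & want) / k

def max_rel_error(logits: torch.Tensor, ref: torch.Tensor) -> float:
    # Largest element-wise relative error vs the reference logits.
    denom = ref.abs().clamp_min(1e-6)
    return float(((logits - ref).abs() / denom).max().item())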

Cross-LLM verification

Verified the bug diagnosis and the fix with a 3-of-3 LLM swarm (gpt-5.1 / Claude-Haiku-4.5 / GPT-oss-20B over the AMD LLM gateway) — unanimous TRUE / HIGH confidence on both.

Why is this Triton path used in production at all?

DSv3.2 sparse-MLA on MI355X populates the indexer KV cache via indexer_k_quant_and_cache_triton (selected over the C++ producer in the ROCm path). All non-kernel-side workarounds (e.g. an externally applied aiter.ops.shuffle.shuffle_weight after the producer) are equivalent in correctness but require either an extra pass over the K bytes per token or a global flag that becomes inconsistent on subsequent token writes. Fixing the producer is the cheapest and most local correctness fix; the kernel itself is unchanged.

Test plan

  • Run scripts/test_e2e_v3.py — BUGGY vs FIXED vs CPP+SHUFFLE end-to-end logits parity for ctx ∈ {128, 512, 1024, 2048, 2049, 4096, 8192, 16384}
  • Run scripts/test_byte_equiv.py — byte equivalence with C++ producer + shuffle_weight
  • Run scripts/test_patched_wrapper.py — same harness using the patched production wrapper via Docker volume mount
  • Run upstream vllm/tests/kernels/attention/test_deepgemm_attention.py once on a CUDA host (the test does not exercise this code path directly, but should still pass)
  • DSv3.2 + MTP end-to-end serving on MI355X (1K input / 100 output, TP4) — confirm GSM8K is back to baseline (~0.9424)

Related

Signed-off-by: Marko Hartikainen <marko.hartikainen@amd.com>

The Triton kernel `_indexer_k_quant_and_cache_kernel` casts the quantized
value to the storage dtype via:

    fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)

The wrapper `indexer_k_quant_and_cache_triton` previously sliced the
uint8-typed `kv_cache` to obtain the K region without re-viewing the
slice as FP8:

    kv_cache_value = kv_cache[:, : block_size * head_dim]   # uint8 dtype

so the kernel's `kv_cache_ptr.type.element_ty` resolved to `tl.uint8` and
the cast became integer truncation (e.g. 1.7 -> 1, -0.3 -> 255 wrap-around)
rather than FP8 e4m3fnuz quantization. The kernel wrote arbitrary integer
bytes that the downstream `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle`
kernel then read as FP8, producing garbage logits.

For context_len <= 2048 the kernel's SplitKV scheduling and the relu+sum
of weighted dot products coincidentally preserved top-K rank, but for
context_len > 2048 top-K parity collapsed (~0% match vs reference) and
DSv3.2 + MTP draft acceptance dropped sharply on MI355X.

The commit message of ab4b563 (which introduced the SHUFFLE indexer)
already noted a 3pp GSM8K regression "under investigation"; this is that
regression's root cause.

Fix: view the K-region slice as the platform FP8 dtype before passing
it to the JIT kernel. The scale region is unchanged (already viewed as
torch.float32 on the next line). After the fix the wrapper produces
byte-identical output to the proven-correct path
(`aiter.ops.cache.indexer_k_quant_and_cache` + `aiter.ops.shuffle.shuffle_weight`)
which is what AITER's own benchmark `bench_deepgemm_attention.py` uses.

Verification on MI355X (heads=64, hd=128, block=64, Preshuffle=True,
ChunkK=256, ue8m0 scales) - kernel logits vs torch reference:

  ctx     BUGGY              FIXED              max_rel(FIXED)
   128   topk=1.000   ->    topk=1.000          1.0e-3
   512   topk=1.000   ->    topk=1.000          8.8e-3
  1024   topk=1.000   ->    topk=1.000          7.7e-3
  2048   topk=1.000   ->    topk=1.000          3.4e-2
  2049   topk=0.814   ->    topk=1.000          1.3e-2
  4096   topk=0.002   ->    topk=1.000          3.3e-2
  8192   topk=0.002   ->    topk=1.000          1.5e-2
 16384   topk=0.005   ->    topk=1.000          3.6e-1

Verified end-to-end through the production wrapper
`indexer_k_quant_and_cache_triton` mounted over the original source -
all measured contexts up to 16384 reach top-K=1.000 vs the bf16-derived
torch reference.

Cross-verified with a 3-of-3 LLM swarm (gpt-5.1, Claude-Haiku-4.5,
GPT-oss-20B) - unanimous TRUE / HIGH confidence on the bug diagnosis
and the fix.

Signed-off-by: Marko Hartikainen <marko.hartikainen@amd.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀
