[ROCm][DSv3.2] Fix FP8 cast in indexer_k_quant_and_cache_triton (top-K accuracy regression for ctx>2048)#2
Closed
maeehart wants to merge 1 commit
Conversation
The Triton kernel `_indexer_k_quant_and_cache_kernel` casts the quantized
value to the storage dtype via:

```python
fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)
```
The wrapper `indexer_k_quant_and_cache_triton` previously sliced the
uint8-typed `kv_cache` to obtain the K region without re-viewing the
slice as FP8:

```python
kv_cache_value = kv_cache[:, : block_size * head_dim]  # uint8 dtype
```
so the kernel's `kv_cache_ptr.type.element_ty` resolved to `tl.uint8` and
the cast became integer truncation (e.g. 1.7 -> 1, -0.3 -> 255 wrap-around)
rather than FP8 e4m3fnuz quantization. The kernel wrote arbitrary integer
bytes that the downstream `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle`
kernel then read as FP8, producing garbage logits.
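The failure mode can be reproduced host-side in a few lines. A minimal PyTorch sketch (not the Triton kernel itself; `torch.float8_e4m3fn` stands in here for the platform FP8 dtype):

```python
import torch

vals = torch.tensor([1.7, 0.5, 1.25])

# What the buggy path effectively stored: float -> uint8 truncation.
as_uint8 = vals.to(torch.uint8)    # tensor([1, 0, 1], dtype=torch.uint8)

# What the downstream kernel then did: reinterpret those bytes as FP8.
print(as_uint8.view(torch.float8_e4m3fn).to(torch.float32))
# ~[0.00195, 0.0, 0.00195] -- tiny denormals, unrelated to the inputs

# What a real FP8 rounding cast produces instead.
print(vals.to(torch.float8_e4m3fn).to(torch.float32))
# [1.75, 0.5, 1.25]
```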
For context_len <= 2048 the kernel's SplitKV scheduling and the relu+sum
of weighted dot products coincidentally preserved top-K rank, but for
context_len > 2048 top-K parity collapsed (~0% match vs reference) and
DSv3.2 + MTP draft acceptance dropped sharply on MI355X.
Fix: view the K-region slice as the platform FP8 dtype before passing
it to the JIT kernel. The scale region is unchanged (already viewed as
torch.float32 on the next line). After the fix the wrapper produces
byte-identical output to the proven-correct path
(`aiter.ops.cache.indexer_k_quant_and_cache` + `aiter.ops.shuffle.shuffle_weight`)
which is what AITER's own benchmark `bench_deepgemm_attention.py` uses.
Verification on MI355X (heads=64, hd=128, block=64, Preshuffle=True,
ChunkK=256, ue8m0 scales) - kernel logits vs torch reference:
| ctx   | BUGGY top-K | FIXED top-K | max_rel (FIXED) |
|------:|------------:|------------:|----------------:|
| 128   | 1.000       | 1.000       | 1.0e-3          |
| 512   | 1.000       | 1.000       | 8.8e-3          |
| 1024  | 1.000       | 1.000       | 7.7e-3          |
| 2048  | 1.000       | 1.000       | 3.4e-2          |
| 2049  | 0.814       | 1.000       | 1.3e-2          |
| 4096  | 0.002       | 1.000       | 3.3e-2          |
| 8192  | 0.002       | 1.000       | 1.5e-2          |
| 16384 | 0.005       | 1.000       | 3.6e-1          |
Verified end-to-end through the production wrapper
`indexer_k_quant_and_cache_triton` mounted over the original source -
all measured contexts up to 16384 reach top-K=1.000 vs the bf16-derived
torch reference.
Cross-verified with a 3-of-3 LLM swarm (gpt-5.1, Claude-Haiku-4.5,
GPT-oss-20B) - unanimous TRUE / HIGH confidence on the bug diagnosis
and the fix.
Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
## Summary
This PR adds an FP8-cast fix to the `rebase-on-upstream-main-v2` branch that backs upstream PR vllm-project#41008. It is meant to be merged into Frida's branch so that vllm-project#41008 picks it up, not as a competing upstream PR. The fix corrects the DeepSeek-V3.2 sparse-MLA + MTP top-K accuracy collapse for `context_len > 2048` on AMD MI355X by patching an FP8-cast bug in `indexer_k_quant_and_cache_triton`.

## Root cause
The Triton kernel `_indexer_k_quant_and_cache_kernel` writes quantized values via `fp8_val = (val.to(tl.float32) / scale).to(kv_cache_ptr.type.element_ty)`. The Python wrapper `indexer_k_quant_and_cache_triton` previously sliced the uint8-typed `kv_cache` to extract the K region without re-viewing it as FP8 (`kv_cache_value = kv_cache[:, : block_size * head_dim]`), so the kernel saw `kv_cache_ptr.type.element_ty == tl.uint8`. Triton's `<float32>.to(tl.uint8)` is integer truncation (e.g. `1.7 -> 1`, `-0.3 -> 255` wrap-around), not FP8 quantization. The kernel wrote arbitrary integer bytes into the cache, and the downstream `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` kernel then read those bytes as FP8 e4m3fnuz, producing garbled K values.
For `context_len <= 2048` the noise happened to preserve top-K rank often enough that the regression was hard to spot, but beyond 2048 the dominant logits no longer aligned with the reference and top-K parity collapsed (~0% match), which is exactly the MTP draft-acceptance cliff observed in production. The original commit message of the SHUFFLE indexer change already noted a 3pp GSM8K regression "under investigation"; this is that regression's root cause.

## The fix
The fix views the K-region slice as the platform FP8 dtype before passing it to the JIT kernel. `current_platform.fp8_dtype()` returns `torch.float8_e4m3fnuz` on AMD ROCm and `torch.float8_e4m3fn` on NVIDIA CUDA, matching the kernel's `IS_FNUZ` constant. The scale region is unchanged; it is already viewed as `torch.float32` on the next line.
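A minimal sketch of the change, assuming the wrapper's surrounding names (`kv_cache`, `block_size`, `head_dim`) from the snippet above; the exact code in the repo may differ slightly:

```python
import torch
from vllm.platforms import current_platform

# BUGGY: the slice inherits the uint8 storage dtype, so the JIT kernel
# resolves kv_cache_ptr.type.element_ty to tl.uint8 and truncates.
kv_cache_value = kv_cache[:, : block_size * head_dim]

# FIXED: re-view the K region as the platform FP8 dtype
# (torch.float8_e4m3fnuz on ROCm, torch.float8_e4m3fn on CUDA).
kv_cache_value = kv_cache[:, : block_size * head_dim].view(
    current_platform.fp8_dtype())

# The scale region is unchanged; it is already viewed as torch.float32.
kv_cache_scale = kv_cache[:, block_size * head_dim :].view(torch.float32)
```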
## Validation

### Byte-equivalence vs proven-correct path

Same input K (bf16), block=64, head_dim=128: the FIXED path produces byte-identical output to `aiter.ops.cache.indexer_k_quant_and_cache` followed by `aiter.ops.shuffle.shuffle_weight`, which is the exact path used by AITER's official benchmark `bench_deepgemm_attention.py`.
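A hedged sketch of what such a byte-equivalence check can look like; `triton_cache` and `aiter_cache` are hypothetical names for the two output buffers, not the harness's actual variables:

```python
import torch

def assert_byte_identical(triton_cache: torch.Tensor,
                          aiter_cache: torch.Tensor) -> None:
    """Compare the two kv_cache buffers at the raw-byte level so that
    dtype views (FP8 K region + fp32 scale region) cannot mask a mismatch."""
    a = triton_cache.view(torch.uint8)
    b = aiter_cache.view(torch.uint8)
    assert a.shape == b.shape, "layout mismatch"
    assert torch.equal(a, b), "Triton path diverges from the AITER path"
```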
### End-to-end logits parity (MI355X, single GPU)

Setup: heads=64, hd=128, block=64, `Preshuffle=True`, `KVBlockSize=64`, `ChunkK=256`, ue8m0 scales, single batch, single next_n. Reference: bf16-derived torch ground truth with the same per-128-group ue8m0 quantization the kernel uses. The results match the table in the commit message above: every measured context reaches top-K = 1.000 after the fix, while the BUGGY path collapses to ~0% beyond 2048.

The 2049 boundary is the SplitKV/multi-block transition where the BUGGY path's accidental top-K agreement breaks down, exactly matching the user-observed MTP regression threshold.
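For reference, a sketch of the top-K parity metric as assumed here (the fraction of the reference top-K indices that the kernel's logits also rank in their top-K); the harness's exact definition may differ:

```python
import torch

def topk_parity(kernel_logits: torch.Tensor,
                ref_logits: torch.Tensor, k: int) -> float:
    """1.000 means the kernel selects exactly the same top-k token set
    as the bf16-derived torch reference."""
    kernel_top = set(kernel_logits.topk(k).indices.tolist())
    ref_top = set(ref_logits.topk(k).indices.tolist())
    return len(kernel_top & ref_top) / k
```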
### Production-wrapper validation

Same harness, but mounting the patched `rocm_aiter_mla_sparse.py` over the original via `docker -v` and calling `indexer_k_quant_and_cache_triton(...)` (not the kernel directly): all measured contexts up to 16384 reach top-K = 1.000 vs the bf16-derived torch reference.

### Cross-LLM verification
Verified the bug diagnosis and the fix with a 3-of-3 LLM swarm (gpt-5.1 / Claude-Haiku-4.5 / GPT-oss-20B over the AMD LLM gateway) — unanimous TRUE / HIGH confidence on both.
## Test plan

Byte-equivalence vs the `aiter.ops.cache.indexer_k_quant_and_cache` + `aiter.ops.shuffle.shuffle_weight` path, the logits-parity sweep over ctx 128-16384, and the production-wrapper validation described above.

## Companion PR
A parallel PR (#1 in frida-andersson/vllm) targets the `rocm/gluon-preshuffle-indexer` branch with the same one-line fix.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>