[Core] Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA #34795

grimulkan wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request enables the use of FP8 KV cache with Decode Context Parallel (DCP) for MLA, which was previously unsupported. The changes are well-structured, including restructuring the decode path to quantize after all-gathering, updating the prefill path to use a new gather_and_maybe_dequant_cache operation, and adding necessary metadata and test coverage. The implementation looks solid. I have one suggestion to add an explicit guard for a known unsupported configuration to improve user experience and error reporting.
Pull request overview
This PR enables the combination of FP8 KV cache (kv_cache_dtype=fp8) with Decode Context Parallel (DCP) > 1 for MLA attention, which was previously blocked by hard assertions. The changes maintain numerical correctness through storage-only FP8 (no new FP8 compute paths) and are motivated by allowing users to benefit from both FP8 memory savings and DCP's latency/memory improvements simultaneously.
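The "storage-only FP8" idea above means the cache holds FP8 values plus a scale, and everything is dequantized back to BF16 before any attention math, so no new FP8 compute paths are introduced. A minimal numpy sketch of that round trip (illustrative helper names and a simulated FP8 value range; the real vLLM path uses fused CUDA kernels such as `ops.scaled_fp8_quant`):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # max representable magnitude in float8_e4m3

def quantize_fp8(x: np.ndarray) -> tuple[np.ndarray, float]:
    """Pick a per-tensor scale so values fit the FP8 range (storage step).
    Assumes a nonzero input; real kernels also cast to float8_e4m3."""
    scale = float(np.abs(x).max()) / FP8_E4M3_MAX
    q = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, scale

def dequantize_fp8(q: np.ndarray, scale: float) -> np.ndarray:
    """Dequantize back to the compute dtype (BF16 in the PR) before attention."""
    return q * scale

kv = np.random.randn(4, 8).astype(np.float32)
q, k_scale = quantize_fp8(kv)          # what the KV cache stores
restored = dequantize_fp8(q, k_scale)  # what the attention kernels consume
assert np.allclose(kv, restored, atol=1e-3)
```

The sketch omits the rounding error of the actual 8-bit cast; it only shows where the scale enters and leaves.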
Changes:
- Restructured the decode Q path to allgather in BF16 first, then optionally quantize to FP8 post-gather for backends with `supports_quant_query_input=True`
- Replaced `cp_gather_cache` with `gather_and_maybe_dequant_cache` for FP8 KV cache in the prefill DCP gather path to handle dtype mismatches
- Added `padded_local_token_to_seq` and `padded_local_chunk_total_token` metadata fields to `ChunkedContextMetadata` for FP8 DCP support
- Passed `k_scale` through to the DCP prefill path instead of hardcoding `None`
- Added a guard for the unsupported `use_fp8_prefill` + DCP > 1 combination
- Extended test parameterization to include FP8 DCP testing
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| vllm/model_executor/layers/attention/mla_attention.py | Restructured decode Q path for DCP+FP8 compatibility, added FP8-aware prefill gather logic, added metadata fields for FP8 DCP support, passed k_scale to DCP path, and guarded against unsupported FP8 prefill with DCP |
| tests/distributed/test_context_parallel.py | Added kv_cache_dtype parameter support and FP8 DCP test configuration for DeepSeek-V2-Lite-Chat |
This approach does differ slightly from @LucasWilkinson's in that it uses BF16 all-gathers for Q, which is not the best for bandwidth efficiency. I can see the problem is because my Triton MLA fp8 PR has

A better approach would be to switch on `supports_quant_query_input`:

But I would need someone with sm90 or sm100 to help test the

EDIT: I found a simple way to include the merge without impacting compatibility. I still cannot test the sm90/100 path, but now the only difference between the two paths is an additional fp8 quantization, so we should have the best of both worlds with lower risk on the untested path.
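The two-path idea from this comment can be sketched as follows (pure-Python stand-ins: the all-gather collective is modeled as a concatenation across ranks, and `quantize_fp8` is a placeholder for vLLM's `ops.scaled_fp8_quant`):

```python
import numpy as np

def allgather(per_rank_q: list[np.ndarray]) -> np.ndarray:
    """Stand-in for the DCP all-gather collective across ranks."""
    return np.concatenate(per_rank_q, axis=0)

def decode_q(per_rank_q: list[np.ndarray],
             supports_quant_query_input: bool,
             quantize_fp8) -> np.ndarray:
    # Both paths all-gather Q in BF16 (float32 stand-in here) ...
    q_full = allgather(per_rank_q)
    # ... and backends that accept FP8 query input quantize after the
    # gather, so the only divergence between the tested (sm120/Triton MLA)
    # path and the untested (sm90/sm100) path is one extra quantization.
    if supports_quant_query_input:
        q_full = quantize_fp8(q_full)
    return q_full
```

This matches the EDIT in the comment: keeping the gather identical in both branches minimizes the risk carried by the branch the author could not test.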
Some speed/VRAM benchmarks on sm120: Kimi K2.5 on RTX 6000 Pro (native int4 experts, Marlin gemm, Triton MLA).

The fp8 versions also require #34597 on sm120.
Confirming that this is working on 8x RTX PRO (AMD Turin):

NCCL_P2P_LEVEL=SYS VLLM_LOG_STATS_INTERVAL=1 NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml VLLM_TEST_FORCE_FP8_MARLIN=1 VLLM_MARLIN_USE_ATOMIC_ADD=1 VLLM_MARLIN_INPUT_DTYPE=fp8 vllm serve moonshotai/Kimi-K2.5 --served-model-name Kimi-K2.5 --trust-remote-code --host 0.0.0.0 --port 5000 --tensor-parallel-size 8 --pipeline-parallel-size 1 --enable-chunked-prefill --enable-prefix-caching --load-format fastsafetensors --tool-call-parser kimi_k2 --enable-auto-tool-choice --reasoning-parser kimi_k2 --async-scheduling --gpu-memory-utilization 0.93 --max-num-batched-tokens 4096 --mm-processor-cache-gb 0 --mm-encoder-tp-mode weights --language-model-only --attention-backend TRITON_MLA --kv-cache-dtype fp8

GPU KV cache size: 449,600 tokens. When --decode-context-parallel-size 8 is used (more KV cache): speed: 66 tok/sec
Cherry-picked from:
- PR vllm-project#34597: FP8 KV cache support for Triton MLA decode attention
- PR vllm-project#34795: Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA

Changes:
- Add fp8/fp8_e4m3 to `TritonMLABackend.supported_kv_cache_dtypes`
- Thread `k_scale`/`v_scale` through the decode attention kernel
- Add FP8 dequant-on-load in the Triton kernels
- Enable the DCP + FP8 KV cache combination
- Add `gather_and_maybe_dequant_cache` for the FP8 DCP prefill path
Previously, MLA attention blocked the combination of FP8 KV cache (`kv_cache_dtype=fp8`) with DCP > 1 via hard asserts. This patch enables the combination by:
- Restructuring the decode Q path to allgather in BF16, then optionally quantize to FP8 post-gather for backends with `supports_quant_query_input`
- Replacing `cp_gather_cache` (dtype-strict) with `gather_and_maybe_dequant_cache` for FP8 KV cache in the prefill DCP gather path
- Passing `k_scale` through to the DCP prefill path (was hardcoded `None`)
- Adding a clear guard for the unsupported `use_fp8_prefill` + DCP > 1 case
- Adding FP8 DCP test parameterization to test_context_parallel.py

Signed-off-by: grimulkan <grimulkan@gmail.com>
Rebased, no change in performance or functionality.
Purpose
MLA attention previously blocked the combination of FP8 KV cache (`kv_cache_dtype=fp8`) with Decode Context Parallel (DCP) > 1 via two hard asserts:

- `assert not fp8_attention, "DCP not support fp8 kvcache now."`
- `assert k_scale is None, "DCP not support scaled kvcache now."`

This meant users had to choose between FP8 KV cache (memory savings) and DCP (more memory savings, latency reduction). This PR enables both to work together, maintaining numerical correctness through storage-only FP8 (no new FP8 compute paths).
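The net effect is that the two blanket asserts shrink to one narrower guard. A sketch of that logic (flag names taken from the PR text; the function itself is not vLLM API):

```python
def check_dcp_fp8_support(fp8_attention: bool, use_fp8_prefill: bool,
                          dcp_world_size: int) -> None:
    """Sketch of the narrowed guard. Before this PR, any fp8_attention with
    DCP > 1 tripped an assert; now only FP8 prefill compute does."""
    if dcp_world_size > 1 and use_fp8_prefill:
        raise AssertionError(
            "use_fp8_prefill with DCP > 1 is not supported; it would require "
            "FP8 workspace allocation in the prefill gather path.")
    # fp8_attention (storage-only FP8 KV cache) with DCP > 1 is now allowed.
```

Keeping the rejection narrow is what lets the FP8 memory savings and DCP coexist while still failing fast, with an actionable message, on the one combination that remains unsupported.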
Changes
mla_attention.py:
Decode Q restructure: with DCP > 1, Q is always allgathered in BF16 first, then optionally quantized to FP8 post-gather if `supports_quant_query_input=True`. This avoids the type mismatch where `_DecodeConcatQuantFP8` produces a single FP8 tensor incompatible with the DCP tuple->cat->allgather flow. The DCP = 1 path is unchanged.

FP8-aware prefill gather: `cp_gather_cache` has a strict `TORCH_CHECK(src.dtype == dst.dtype)` that crashes when the FP8 cache meets the BF16 workspace. For FP8 KV cache (excluding `fp8_ds_mla`), the code now calls `gather_and_maybe_dequant_cache`, which fuses the gather with FP8->BF16 dequantization. The non-FP8 path continues to use `cp_gather_cache`.

Metadata additions: added `padded_local_token_to_seq` and `padded_local_chunk_total_token` fields to `ChunkedContextMetadata`, computed in `build()`, as required by `gather_and_maybe_dequant_cache`.

`k_scale` passthrough: `forward_mha` now passes the real `k_scale` to the DCP prefill path instead of a hardcoded `None`.

Guard for `use_fp8_prefill` + DCP > 1: added a clear assert with an actionable error message. This combination would require FP8 workspace allocation (only relevant for sm10x + FlashInfer/TRT-LLM + `use_prefill_query_quantization`) and is not supported.

test_context_parallel.py:
Added `kv_cache_dtype` support to `CPTestOptions`/`CPTestSettings.detailed()` and added a new test entry for DeepSeek-V2-Lite-Chat with `dcp=4, kv_cache_dtype=fp8`.

Test Plan
New test: `test_context_parallel.py::test_cp_generation` with `kv_cache_dtype="fp8", dcp_size=4` - GSM8K 256-question 5-shot accuracy eval with DeepSeek-V2-Lite-Chat

Regression tests:
- `test_context_parallel.py::test_cp_generation` with `kv_cache_dtype=auto, dcp_size=4` - existing DCP test
- `test_mla_backends.py::test_backend_correctness` - MLA backend unit tests (DCP=1 forward paths)
Environment: 16x GPUs, sm120, TritonMLA backend (the only MLA backend that works on sm120)
- `test_cp_generation` (regression): `tp=4, dcp=4, kv_cache_dtype=auto`
- `test_cp_generation` (new): `tp=4, dcp=4, kv_cache_dtype=fp8`
- `test_backend_correctness`: failures only for a backend hitting `RuntimeError: Error Internal` and for FlashInfer MLA (XQA MLA only supports fp8 on SM120); not related to this PR.

Full lm_eval GSM8K results (Kimi-K2.5, tp=16, TritonMLA, 5-shot):
- `dcp=16, kv_cache_dtype=auto` (baseline)
- `dcp=1, kv_cache_dtype=fp8` (baseline)
- `dcp=16, kv_cache_dtype=fp8` (this PR)

FP8 + DCP=16 accuracy matches both baselines within error margins.
Known Limitations
- `fp8_ds_mla` + DCP > 1: not supported (different storage format with embedded block scales). Falls through to `cp_gather_cache`, which will dtype-check at runtime.
- `use_fp8_prefill` + DCP > 1: explicitly guarded with an assert. Would require FP8 workspace allocation; currently only possible with sm10x + FlashInfer/TRT-LLM + the user selecting `use_prefill_query_quantization`.
- `supports_quant_query_input=True` backends: the post-gather FP8 quant path was added but not tested on this (sm120) hardware (it requires sm90a for FlashMLA/CutlassMLA). The path uses one additional `ops.scaled_fp8_quant`, a well-tested primitive, with all other operations the same as the Triton MLA sm120 path. Risk is minimal.
- `gather_and_maybe_dequant_cache`: baseline vLLM has a hardcoded `head_dim == 576` constraint, currently limiting FP8 DCP to DeepSeek V2/V3/R1 family models. This constraint is duplicated in this PR.

Note
To run this PR on sm120, Triton MLA must also support `kv-cache-dtype fp8` from #34597, since that's the only MLA backend that works on sm120. At this time, that PR is not yet merged, but the two features are independent and can be merged separately.
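Putting the Known Limitations together, eligibility for the FP8 + DCP combination could be summarized as follows (a sketch only; flag names and the 576-dim constraint are taken from the PR text, and the function itself is not vLLM API):

```python
def fp8_dcp_supported(kv_cache_dtype: str, dcp_world_size: int,
                      head_dim: int, use_fp8_prefill: bool) -> bool:
    """Sketch of when the FP8 KV cache + DCP combination from this PR applies."""
    if dcp_world_size <= 1:
        return True                # DCP off: nothing from this PR is needed
    if kv_cache_dtype == "fp8_ds_mla":
        return False               # embedded block scales: unsupported
    if use_fp8_prefill:
        return False               # would need FP8 workspace allocation
    if kv_cache_dtype.startswith("fp8"):
        return head_dim == 576     # 512 latent + 64 rope dims (DeepSeek MLA)
    return True                    # non-FP8 caches are unaffected
```

Note the `fp8_ds_mla` check must precede the generic `fp8` prefix check, since its name also starts with `fp8`.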