
[Core] Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA#34795

Open
grimulkan wants to merge 1 commit into vllm-project:main from grimulkan:dcp-fp8-mla

Conversation


@grimulkan grimulkan commented Feb 18, 2026

Previously, MLA attention blocked the combination of FP8 KV cache (kv_cache_dtype=fp8) with DCP > 1 via hard asserts. This patch enables the combination by:

  • Restructuring the decode Q path to allgather in BF16, then optionally quantize to FP8 post-gather for backends with supports_quant_query_input
  • Replacing cp_gather_cache (dtype-strict) with gather_and_maybe_dequant_cache for FP8 KV cache in the prefill DCP gather path
  • Passing k_scale through to the DCP prefill path (was hardcoded None)
  • Adding a clear guard for the unsupported use_fp8_prefill + DCP > 1 case
  • Adding FP8 DCP test parameterization to test_context_parallel.py

Purpose

MLA attention previously blocked the combination of FP8 KV cache (kv_cache_dtype=fp8) with Decode Context Parallel (DCP) > 1 via two hard asserts:

  • Decode path: assert not fp8_attention, "DCP not support fp8 kvcache now."
  • Prefill DCP gather: assert k_scale is None, "DCP not support scaled kvcache now."

This meant users had to choose between FP8 KV cache (memory savings) and DCP (more memory savings, latency reduction). This PR enables both to work together, maintaining numerical correctness through storage-only FP8 (no new FP8 compute paths).

Changes

mla_attention.py:

  1. Decode Q restructure: DCP > 1 always allgathers Q in BF16 first, then optionally quantizes to FP8 post-gather if supports_quant_query_input=True. This avoids the type mismatch where _DecodeConcatQuantFP8 produces a single FP8 tensor incompatible with the DCP tuple->cat->allgather flow. DCP = 1 path is unchanged.

  2. FP8-aware prefill gather: cp_gather_cache has a strict TORCH_CHECK(src.dtype == dst.dtype) that crashes when FP8 cache meets BF16 workspace. For FP8 KV cache (excluding fp8_ds_mla), the code now calls gather_and_maybe_dequant_cache which fuses gather + FP8->BF16 dequantization. Non-FP8 path continues to use cp_gather_cache.

  3. Metadata additions: Added padded_local_token_to_seq and padded_local_chunk_total_token fields to ChunkedContextMetadata, computed in build(), required by gather_and_maybe_dequant_cache.

  4. k_scale passthrough: forward_mha now passes the real k_scale to the DCP prefill path instead of a hardcoded None.

  5. Guard for use_fp8_prefill + DCP > 1: Added a clear assert with an actionable error message. This combination would require FP8 workspace allocation (only relevant for sm10x + FlashInfer/TRT-LLM + use_prefill_query_quantization) and is not supported.
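Change (1) above can be summarized as a small control-flow sketch. Tensor ops are stubbed with plain lists so the logic is self-contained here; in vLLM they correspond to torch.cat, get_dcp_group().all_gather, and ops.scaled_fp8_quant, and all helper names below are illustrative, not vLLM's actual API.

```python
def all_gather(local_chunk, world_size):
    # Stub of the DCP all-gather collective; every rank contributes an
    # identical chunk in this sketch.
    return [x for _ in range(world_size) for x in local_chunk]

def quant_fp8(values):
    # Stub: tag values as FP8-quantized (real op is ops.scaled_fp8_quant).
    return [("fp8", v) for v in values]

def decode_q(ql_nope, q_pe, *, dcp_world_size, fp8_attention,
             supports_quant_query_input):
    if dcp_world_size > 1:
        # Always concatenate and all-gather in BF16 first; this avoids the
        # type mismatch where a pre-gather FP8 quant produces a single FP8
        # tensor incompatible with the tuple->cat->allgather flow.
        q = all_gather(ql_nope + q_pe, dcp_world_size)
        # Optionally quantize post-gather, only for backends that consume
        # FP8 queries (supports_quant_query_input=True).
        if fp8_attention and supports_quant_query_input:
            q = quant_fp8(q)
        return q
    # DCP == 1 path is unchanged by the PR (not sketched here).
    return ql_nope + q_pe
```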
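Change (2)'s dispatch between the two gather kernels can be sketched as follows. The kernel bodies are stubbed (the real ones are CUDA ops); only the names cp_gather_cache and gather_and_maybe_dequant_cache come from the PR, the rest is illustrative.

```python
def cp_gather_cache(src_dtype, dst_dtype):
    # Stub of the strict kernel: TORCH_CHECK(src.dtype == dst.dtype).
    if src_dtype != dst_dtype:
        raise RuntimeError("cp_gather_cache: src/dst dtype mismatch")
    return dst_dtype

def gather_and_maybe_dequant_cache(src_dtype, dst_dtype):
    # Stub: fuses the gather with FP8 -> BF16 dequantization.
    return dst_dtype

def gather_context_kv(kv_cache_dtype, workspace_dtype="bf16"):
    if kv_cache_dtype.startswith("fp8") and kv_cache_dtype != "fp8_ds_mla":
        # FP8 cache + BF16 workspace: fused gather + dequant.
        return gather_and_maybe_dequant_cache(kv_cache_dtype, workspace_dtype)
    # Non-FP8 (and fp8_ds_mla) keep the original dtype-strict path.
    return cp_gather_cache(kv_cache_dtype, workspace_dtype)
```

Note how fp8_ds_mla deliberately falls through to the strict path, matching the Known Limitations section.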

test_context_parallel.py:

  1. FP8 DCP test parameterization: Added kv_cache_dtype support to CPTestOptions/CPTestSettings.detailed() and added a new test entry for DeepSeek-V2-Lite-Chat with dcp=4, kv_cache_dtype=fp8.
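An illustrative sketch of the test-side change; the real CPTestOptions / CPTestSettings in tests/distributed/test_context_parallel.py carry many more fields, so treat the shapes below as assumptions.

```python
from dataclasses import dataclass

@dataclass
class CPTestOptions:               # simplified stand-in
    tp_size: int
    dcp_size: int
    kv_cache_dtype: str = "auto"   # new knob: "auto" or "fp8"

def detailed(kv_cache_dtype: str = "auto") -> list[CPTestOptions]:
    # One entry per parallel layout; the PR adds an fp8 variant for
    # DeepSeek-V2-Lite-Chat with dcp=4.
    return [CPTestOptions(tp_size=4, dcp_size=4,
                          kv_cache_dtype=kv_cache_dtype)]

cases = detailed() + detailed(kv_cache_dtype="fp8")
```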

Test Plan

New test:

  • test_context_parallel.py::test_cp_generation with kv_cache_dtype="fp8", dcp_size=4 — GSM8K 256-question 5-shot accuracy eval with DeepSeek-V2-Lite-Chat
  • End-to-end test with lm_eval (GSM8K) with Kimi K2.5 on sm120 using new settings

Regression tests:

  • test_context_parallel.py::test_cp_generation with kv_cache_dtype=auto, dcp_size=4 — existing DCP test
  • test_mla_backends.py::test_backend_correctness — MLA backend unit tests (DCP=1 forward paths)
  • End-to-end test with lm_eval (GSM8K) with Kimi K2.5 on sm120 using existing previously supported settings

Test Results

Environment: 16x GPUs, sm120, TritonMLA backend (the only MLA backend that works on sm120)

| Test | Config | Result |
| --- | --- | --- |
| test_cp_generation (regression) | tp=4, dcp=4, kv_cache_dtype=auto | PASS (GSM8K accuracy ≥ 0.64) |
| test_cp_generation (new) | tp=4, dcp=4, kv_cache_dtype=fp8 | PASS (GSM8K accuracy ≥ 0.64) |
| test_backend_correctness | 48 parameterizations, DCP=1 | 16 pass, 32 fail (pre-existing) |

All test_backend_correctness failures are pre-existing sm120 backend issues: Cutlass MLA (RuntimeError: Error Internal) and FlashInfer MLA (XQA MLA only supports fp8 on SM120). They are not related to this PR.

Full lm_eval GSM8K results (Kimi-K2.5, tp=16, TritonMLA, 5-shot):

| Config | exact_match (flexible) | exact_match (strict) |
| --- | --- | --- |
| dcp=16, kv_cache_dtype=auto (baseline) | 0.9363 ± 0.0067 | 0.9363 ± 0.0067 |
| dcp=1, kv_cache_dtype=fp8 (baseline) | 0.9378 ± 0.0067 | 0.9371 ± 0.0067 |
| dcp=16, kv_cache_dtype=fp8 (this PR) | 0.9371 ± 0.0067 | 0.9371 ± 0.0067 |

FP8 + DCP=16 accuracy matches both baselines within error margins.

Known Limitations

  • fp8_ds_mla + DCP > 1: Not supported (it uses a different storage format with embedded block scales). This case falls through to cp_gather_cache, whose runtime dtype check will reject it.
  • use_fp8_prefill + DCP > 1: Explicitly guarded with an assert. It would require FP8 workspace allocation, currently only possible with sm10x + FlashInfer/TRT-LLM + user-selected use_prefill_query_quantization.
  • supports_quant_query_input=True backends: The post-gather FP8 quant path is added but untested on this (sm120) hardware, since it requires sm90a for FlashMLA/CutlassMLA. The path adds only one ops.scaled_fp8_quant call (a well-tested primitive); every other operation matches the Triton MLA sm120 path, so the risk is minimal.
  • gather_and_maybe_dequant_cache: Baseline vLLM hardcodes a head_dim == 576 constraint, currently limiting FP8 DCP to the DeepSeek V2/V3/R1 family of models. This PR keeps that constraint.

Note

Running this PR on sm120 also requires Triton MLA support for kv-cache-dtype fp8 from #34597, since that is the only backend on sm120 that supports it. That PR is not yet merged, but the two features are independent and can land separately.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copilot AI review requested due to automatic review settings February 18, 2026 10:49

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables the use of FP8 KV cache with Decode Context Parallel (DCP) for MLA, which was previously unsupported. The changes are well-structured, including restructuring the decode path to quantize after all-gathering, updating the prefill path to use a new gather_and_maybe_dequant_cache operation, and adding necessary metadata and test coverage. The implementation looks solid. I have one suggestion to add an explicit guard for a known unsupported configuration to improve user experience and error reporting.


Copilot AI left a comment


Pull request overview

This PR enables the combination of FP8 KV cache (kv_cache_dtype=fp8) with Decode Context Parallel (DCP) > 1 for MLA attention, which was previously blocked by hard assertions. The changes maintain numerical correctness through storage-only FP8 (no new FP8 compute paths) and are motivated by allowing users to benefit from both FP8 memory savings and DCP's latency/memory improvements simultaneously.

Changes:

  • Restructured the decode Q path to allgather in BF16 first, then optionally quantize to FP8 post-gather for backends with supports_quant_query_input=True
  • Replaced cp_gather_cache with gather_and_maybe_dequant_cache for FP8 KV cache in the prefill DCP gather path to handle dtype mismatches
  • Added padded_local_token_to_seq and padded_local_chunk_total_token metadata fields to ChunkedContextMetadata for FP8 DCP support
  • Passed k_scale through to the DCP prefill path instead of hardcoding None
  • Added a guard for the unsupported use_fp8_prefill + DCP > 1 combination
  • Extended test parameterization to include FP8 DCP testing

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| vllm/model_executor/layers/attention/mla_attention.py | Restructured decode Q path for DCP+FP8 compatibility, added FP8-aware prefill gather logic, added metadata fields for FP8 DCP support, passed k_scale to DCP path, and guarded against unsupported FP8 prefill with DCP |
| tests/distributed/test_context_parallel.py | Added kv_cache_dtype parameter support and FP8 DCP test configuration for DeepSeek-V2-Lite-Chat |



grimulkan commented Feb 18, 2026

This approach differs slightly from @LucasWilkinson's in that it all-gathers Q in BF16, which is not optimal for bandwidth. The reason is that my Triton MLA fp8 PR has supports_quant_query_input = False (it internally runs attention in BF16), while the other backends expect quantized Q, and future backends may also set supports_quant_query_input = False. My current approach maximizes compatibility but is less efficient.

A better approach would be to switch on supports_quant_query_input:

```python
if self.impl.dcp_world_size > 1:
    if fp8_attention and self.impl.supports_quant_query_input:
        # FP8 quant first, all_gather in FP8 (half bandwidth)
        mqa_q = self._decode_concat_quant_fp8_op(
            mqa_ql_nope, mqa_q_pe, self._q_scale)
        mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
    else:
        # BF16 all_gather (TritonMLA and non-FP8 cases)
        mqa_q = torch.cat((mqa_ql_nope, mqa_q_pe), dim=-1)
        mqa_q = get_dcp_group().all_gather(mqa_q, dim=1)
```

But I would need someone with sm90 or sm100 to help test the supports_quant_query_input = True path in that case.

EDIT: I found a simple way to include the merge without impacting compatibility. I still cannot test the sm90/100 path, but now the only difference between the 2 paths is an additional fp8 quantization, so we should have the best of both worlds with lower risk on the untested path.

@grimulkan

Some speed/VRAM benchmarks on sm120.

Kimi K2.5 on RTX 6000 Pro (native int4 experts, Marlin gemm, Triton MLA)

| Cards | TP | DCP | PP | KV Cache | Total KV Cache Space | Generation Speed (@ 0 context) |
| --- | --- | --- | --- | --- | --- | --- |
| 8 | 8 | 8 | 1 | fp8 | 3M tok | 68 tok/s |
| 8 | 8 | 1 | 1 | fp8 | 380K tok | 79 tok/s |
| 8 | 8 | 8 | 1 | bf16 | 1.5M tok | 67 tok/s |
| 8 | 8 | 1 | 1 | bf16 | 190K tok | 78 tok/s |
| 16 | 16 | 16 | 1 | fp8 | 20M tok | 43 tok/s |
| 16 | 16 | 1 | 1 | fp8 | 1.25M tok | 64 tok/s |
| 16 | 16 | 16 | 1 | bf16 | 10M tok | 42 tok/s |
| 16 | 16 | 1 | 1 | bf16 | 638K tok | 60 tok/s |
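A rough capacity model consistent with the numbers above (an observation about the benchmarks, not code from the PR): DCP shards the KV cache across ranks, multiplying token capacity roughly by dcp_size, and fp8 roughly doubles capacity versus bf16.

```python
def approx_kv_capacity(bf16_dcp1_tokens, dcp_size, kv_cache_dtype):
    # Capacity scales ~linearly with dcp_size; fp8 halves bytes per token.
    dtype_factor = 2 if kv_cache_dtype == "fp8" else 1
    return bf16_dcp1_tokens * dcp_size * dtype_factor

# 8-card baseline from the table: 190K tokens at bf16, dcp=1.
print(approx_kv_capacity(190_000, 8, "fp8"))   # ~3.04M; table reports 3M
print(approx_kv_capacity(190_000, 8, "bf16"))  # ~1.52M; table reports 1.5M
```

The same model holds for the 16-card rows (638K bf16 baseline gives ~10.2M at bf16/dcp=16 and ~20.4M at fp8/dcp=16, versus the reported 10M and 20M).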

The fp8 versions also require #34597 on sm120.
Some of this will likely need a rebase after #33529 is merged (the results above do not include those improvements).

@voipmonitor

Confirming that this is working on 8x RTX PRO (AMD Turin host):

NCCL_P2P_LEVEL=SYS VLLM_LOG_STATS_INTERVAL=1 NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml VLLM_TEST_FORCE_FP8_MARLIN=1 VLLM_MARLIN_USE_ATOMIC_ADD=1 VLARLIN_INPUT_DTYPE=fp8 vllm serve moonshotai/Kimi-K2.5 --served-model-name Kimi-K2.5 --trust-remote-code --host 0.0.0.0 --port 5000 --tensor-parallel-size 8 --pipeline-parallel-size 1 --enable-chunked-prefill --enable-prefix-caching --load-format fastsafetensors --tool-call-parser kimi_k2 --enable-auto-tool-choice --reasoning-parser kimi_k2 --async-scheduling --gpu-memory-utilization 0.93 --max-num-batched-tokens 4096 --mm-processor-cache-gb 0 --mm-encoder-tp-mode weights --language-model-only --attention-backend TRITON_MLA --kv-cache-dtype fp8

GPU KV cache size: 449,600 tokens
speed: 79tok/sec

when --decode-context-parallel-size 8 is used (more KV cache):
GPU KV cache size: 3,621,504 tokens

speed: 66tok/sec

ec-jt added a commit to ec-jt/vllm that referenced this pull request Mar 1, 2026
Cherry-picked from:
- PR vllm-project#34597: FP8 KV cache support for Triton MLA decode attention
- PR vllm-project#34795: Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA

Changes:
- Add fp8/fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
- Thread k_scale/v_scale through decode attention kernel
- Add FP8 dequant-on-load in Triton kernels
- Enable DCP + FP8 KV cache combination
- Add gather_and_maybe_dequant_cache for FP8 DCP prefill path
Previously, MLA attention blocked the combination of FP8 KV cache
(kv_cache_dtype=fp8) with DCP > 1 via hard asserts. This patch enables
the combination by:

- Restructuring the decode Q path to allgather in BF16, then optionally
  quantize to FP8 post-gather for backends with supports_quant_query_input
- Replacing cp_gather_cache (dtype-strict) with gather_and_maybe_dequant_cache
  for FP8 KV cache in the prefill DCP gather path
- Passing k_scale through to the DCP prefill path (was hardcoded None)
- Adding a clear guard for the unsupported use_fp8_prefill + DCP > 1 case
- Adding FP8 DCP test parameterization to test_context_parallel.py

Signed-off-by: grimulkan <grimulkan@gmail.com>
@grimulkan

Rebased, no change in performance or functionality.
