
[DCP][Bugfix][CI] Fix accuracy issue of DCP when using FLASH_ATTN_MLA#30309

Merged
LucasWilkinson merged 1 commit into vllm-project:main from FENP:dcp-interleave-fix
Dec 9, 2025

Conversation

Contributor

@FENP FENP commented Dec 9, 2025

Purpose

#25049 added MTP support for DCP with FA3, but it only works when cp_kv_cache_interleave_size=1.

For the FA3 backend, we should check cp_kv_cache_interleave_size and set supports_dcp_with_varlen accordingly.
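A minimal sketch of the gating logic described above. supports_dcp_with_varlen and cp_kv_cache_interleave_size are the names used in this PR; the surrounding class and its placement are illustrative assumptions, not the literal diff:

```python
# Sketch (assumed structure): only advertise varlen-DCP support when the
# KV cache is interleaved one token at a time across DCP ranks.
class FlashAttnMLABackendSketch:
    def __init__(self, cp_kv_cache_interleave_size: int):
        # With interleave_size > 1, each DCP rank holds contiguous chunks
        # of tokens rather than a strict per-token round-robin layout,
        # which the FA3 varlen decode path does not handle correctly.
        self.supports_dcp_with_varlen = cp_kv_cache_interleave_size == 1
```

With interleave size 1 the flag stays enabled; with interleave size 64 (the configuration tested below) it is disabled, routing such requests through the prefill path instead.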

Test Plan

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat/ --gpu-memory-utilization 0.9 --tensor-parallel-size 4 --decode-context-parallel-size 4 --cp-kv-cache-interleave-size 64
python ./tests/evals/gsm8k/gsm8k_eval.py

Test Result

  • Main
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%| 1319/1319 [00:40<00:00, 32.62it/s]

Results:
Accuracy: 0.028
Invalid responses: 0.007
Total latency: 40.447 s
Questions per second: 32.611
Total output tokens: 189017
Output tokens per second: 4673.216
  • This PR
Running GSM8K evaluation: 1319 questions, 5-shot
Evaluating: 100%| 1319/1319 [00:45<00:00, 28.81it/s]

Results:
Accuracy: 0.639
Invalid responses: 0.002
Total latency: 45.794 s
Questions per second: 28.803
Total output tokens: 161057
Output tokens per second: 3516.982

cc @LucasWilkinson @minosfuture @pisceskkk


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
@FENP FENP requested a review from pavanimajety as a code owner December 9, 2025 05:47

@FENP FENP changed the title [CI][Bugfix] Fix accuracy issue of DCP when using FLASH_ATTN_MLA [DCP][Bugfix][CI] Fix accuracy issue of DCP when using FLASH_ATTN_MLA Dec 9, 2025
@mergify mergify bot added the v1 label Dec 9, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses an accuracy issue in Decode Context Parallelism (DCP) when used with FlashAttention MLA and a cp_kv_cache_interleave_size greater than 1. The fix correctly disables supports_dcp_with_varlen in this configuration, forcing requests into the prefill path and resolving the bug. The code change is correct and directly targets the issue. My main feedback is to enhance the test coverage to include the specific configuration that was manually tested and shown to be fixed, ensuring this scenario is covered by CI to prevent future regressions.

    CPTestSettings.detailed(dcp_multipliers=[1]),
    CPTestSettings.detailed(
        dcp_multipliers=[0.5],
        cp_kv_cache_interleave_size=64,
Contributor


Severity: high

The PR description shows a manual test with dcp_size == tp_size (i.e., dcp_multiplier=1) and cp_kv_cache_interleave_size=64 to demonstrate the fix. However, this specific test case seems to be missing from the automated tests after the changes. To prevent future regressions, it would be beneficial to include this configuration in the test suite. You can achieve this by adding 1 to the dcp_multipliers list for the test case with cp_kv_cache_interleave_size=64.

Suggested change
dcp_multipliers=[0.5],
dcp_multipliers=[0.5, 1],

Collaborator

@LucasWilkinson LucasWilkinson left a comment


LGTM

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) December 9, 2025 06:05
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 9, 2025
@LucasWilkinson LucasWilkinson merged commit 67475a6 into vllm-project:main Dec 9, 2025
58 of 59 checks passed
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…vllm-project#30309)

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1



2 participants