[BugFix] Fix IMA FlashMLA full cuda-graph and DP + Update FlashMLA#21691
Conversation
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Code Review
This pull request correctly addresses a potential issue with FlashMLA under full CUDA graph capture, especially in distributed settings. By moving the buffer allocation from a lazy, in-place approach to a pre-allocation strategy in the __init__ method, the code becomes more robust and compliant with CUDA graph requirements. The changes are logical and well-implemented. I have one suggestion to replace a magic number with a constant to improve long-term maintainability.
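The pre-allocation strategy the review describes can be sketched as follows. This is a hedged illustration with hypothetical names (not vLLM's actual code): buffers used inside a captured CUDA graph must keep stable addresses across replays, so they are sized once in __init__ for the largest capture size instead of being allocated lazily. A plain Python list stands in for the GPU tensor.

```python
class PreallocatedDecodeBuffers:
    """Sketch of the pre-allocation strategy: CUDA-graph buffers are
    sized once at init for the largest capture size, so their addresses
    stay stable across graph replays. Names are hypothetical."""

    def __init__(self, max_num_reqs: int):
        # Pre-allocate up front for the largest supported capture size.
        self.max_num_reqs = max_num_reqs
        self.cg_buf_num_splits = [0] * (max_num_reqs + 1)

    def build_decode(self, num_reqs: int):
        # Mirrors the bound check seen in the traceback below: every
        # capture size must fit in the buffer allocated at init time.
        assert num_reqs <= self.max_num_reqs, (
            "capture size exceeds pre-allocated CUDA-graph buffer")
        return self.cg_buf_num_splits[: num_reqs + 1]
```

With the buffer sized for the largest capture size at init, every smaller capture reuses a slice of the same allocation rather than triggering a fresh (and CUDA-graph-unsafe) allocation.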
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
yewentao256 left a comment:
vllm serve deepseek-ai/DeepSeek-V2-Lite --port 9256 --enable-expert-parallel --data-parallel-size 2 --trust-remote-code -O '{"full_cuda_graph": true}' --cuda-graph-sizes 16 32 64 128 256 512
Originally:
(EngineCore_0 pid=94527) AssertionError
(EngineCore_1 pid=94528) answer = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/utils/__init__.py", line 2948, in run_method
(EngineCore_1 pid=94528) return func(*args, **kwargs)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/v1/worker/gpu_worker.py", line 330, in compile_or_warm_up_model
(EngineCore_1 pid=94528) self.model_runner._dummy_run(
(EngineCore_1 pid=94528) File "/home/wentao/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(EngineCore_1 pid=94528) return func(*args, **kwargs)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/v1/worker/gpu_model_runner.py", line 2206, in _dummy_run
(EngineCore_1 pid=94528) .build_for_cudagraph_capture(common_attn_metadata)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/v1/attention/backends/mla/common.py", line 580, in build_for_cudagraph_capture
(EngineCore_1 pid=94528) return self.build(0, m)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/v1/attention/backends/mla/common.py", line 705, in build
(EngineCore_1 pid=94528) decode_metadata = self._build_decode(
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) File "/home/wentao/vllm/vllm/v1/attention/backends/mla/flashmla.py", line 100, in _build_decode
(EngineCore_1 pid=94528) assert n <= self.cg_buf_num_splits.size(0)
(EngineCore_1 pid=94528) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_1 pid=94528) AssertionError
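The failure mode behind the AssertionError above can be sketched as follows. This is a hedged illustration with hypothetical names, not vLLM's actual code: a buffer allocated lazily on first use is sized for whichever capture size happens to arrive first, so a later, larger capture size trips the bound check.

```python
class LazilyAllocatedBuffers:
    """Sketch of the pre-fix failure mode (hypothetical names): lazy
    allocation fixes the buffer size at the first call, so any later,
    larger capture size fails the bound check."""

    def __init__(self):
        self.cg_buf_num_splits = None  # allocated on first build

    def build_decode(self, num_reqs: int):
        if self.cg_buf_num_splits is None:
            # Lazy allocation: size is pinned by the first capture size seen.
            self.cg_buf_num_splits = [0] * (num_reqs + 1)
        # Fails whenever a later call needs more room than the first one did.
        assert num_reqs + 1 <= len(self.cg_buf_num_splits)
        return self.cg_buf_num_splits[: num_reqs + 1]
```

If the first build happens for a small capture size and a later one (e.g. when data parallelism pads a rank's batch up to a larger graph size) needs more room, the second call asserts, which matches the traceback above.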
Now:
(APIServer pid=126325) INFO: Started server process [126325]
(APIServer pid=126325) INFO: Waiting for application startup.
(APIServer pid=126325) INFO: Application startup complete.
So I think this PR fixed the issue, thanks for the work! @tlrmchlsmth Could you trigger CI?
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.

Purpose
Merge: vllm-project/FlashMLA#3 first
Fix an illegal memory access (IMA) that occurs when using FlashMLA with full CUDA graphs and wide expert parallelism (wide-EP).
Also updates FlashMLA (i.e. #17027), since the FlashMLA changes were made on top of that. #17027 was back-burnered because it showed a slight slowdown in the TP attention case, but it should provide a speedup for DP attention.
Test Plan
The failure was reproduced on an llm-d benchmark.
Test Result
The llm-d benchmark now passes with this fix.
(Optional) Documentation Update