[Perf] Add decode full-graph support to FlashInfer-MLA backend #26313
LucasWilkinson merged 1 commit into vllm-project:main
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Code Review
This pull request correctly enables full CUDA graph support for decode operations in the FlashInfer-MLA attention backend. The change is implemented by creating a new FlashInferMLAMetadataBuilder class that inherits from MLACommonMetadataBuilder and sets the cudagraph_support attribute to AttentionCGSupport.UNIFORM_BATCH. The FlashInferMLABackend is then updated to use this new builder. The approach is clean, follows the existing design patterns in the codebase, and seems to correctly enable the feature as described. The changes are minimal and well-targeted. I found no issues of high or critical severity.
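For context, here is a minimal sketch of the change as the review summary describes it. The class and attribute names come from the summary above; the import paths and the `get_builder_cls` hook are assumptions about vLLM's source layout, not the exact diff.

```python
# Sketch only: names follow the review summary above; the module paths
# below are assumptions about vLLM's v1 attention backend layout.
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport


class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Advertise full CUDA graph support for uniform (decode-only) batches.
    cudagraph_support = AttentionCGSupport.UNIFORM_BATCH


class FlashInferMLABackend(MLACommonBackend):
    @staticmethod
    def get_builder_cls() -> type[FlashInferMLAMetadataBuilder]:
        # Point the backend at the new builder so the annotation takes effect.
        return FlashInferMLAMetadataBuilder
```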
💡 Codex Review
Here are some automated review suggestions for this pull request.
Purpose
The full CUDA graph support annotation was missing from the FlashInfer-MLA backend, even though the implementation already supports it.
Running DSR1-FP4 on 4xB200 gets me 97 TPS:
I also tested on a local development branch for MTP containing #25984 and #25987.
On that branch, with 3 MTP speculative tokens, I get 165 TPS and passing GSM8k evals.
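For reproduction, full-graph capture has to be enabled on the serving side. A hedged sketch follows: the `CompilationConfig`/`CUDAGraphMode` names are assumed from vLLM's public config API, and the model id and exact settings used for the numbers above are placeholders, not the author's actual invocation.

```python
# Sketch only: enabling full CUDA graph capture for decode batches.
# CompilationConfig / CUDAGraphMode are assumed names from vLLM's config
# API; the model id below is a hypothetical stand-in for DSR1-FP4.
from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

llm = LLM(
    model="nvidia/DeepSeek-R1-FP4",  # hypothetical model id
    tensor_parallel_size=4,          # matches the 4xB200 setup above
    compilation_config=CompilationConfig(
        # Capture decode steps as full graphs; prefill stays piecewise.
        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
    ),
)
```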
Test Plan
GSM8k was run as follows:
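The exact command is not preserved here; below is a hedged sketch of a typical GSM8k eval via lm-evaluation-harness against a running vLLM server. The `simple_evaluate` call and `local-completions` model type are from lm-eval's public API; the model id, URL, and concurrency are placeholders.

```python
# Sketch only: GSM8k via lm-evaluation-harness against a local vLLM
# server. Model id, base_url, and num_concurrent are placeholders.
import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(
        "model=nvidia/DeepSeek-R1-FP4,"                     # hypothetical
        "base_url=http://localhost:8000/v1/completions,"
        "num_concurrent=32"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])  # accuracy should match the baseline
```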
Test Result
Matches the baseline: