[Perf] Refactor cudagraph_support to enable full CUDA graphs for spec decoding with FlashInfer#28479
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Documentation preview: https://vllm--28479.org.readthedocs.build/en/28479/
Code Review
This pull request introduces a well-designed refactoring to enable more flexible and dynamic CUDA graph support for attention backends. By making `cudagraph_support` a private member and introducing a new `get_cudagraph_support` method, the code now dynamically determines CUDA graph capability on a per-backend, per-KV-group basis. This change is crucial for enabling full CUDA graph support for speculative decoding with FlashInfer on specific hardware like Blackwell. The updates to `_check_and_update_cudagraph_mode` and the corresponding documentation changes are clear and correct. Overall, this is a solid performance enhancement with a clean implementation.
LucasWilkinson
left a comment
LGTM
I'd like to work towards reverting #27427 (and moving back to this being an instance property) in the future, but we need broader cudagraph refactors to get there.
vadiklyutiy
left a comment
Before, we used `use_trtllm_attention` for checking both prefill and decode.
Right now it seems `use_trtllm_attention` is used for checking prefill only, alongside `can_use_trtllm_attention`.
Could we refactor (as sketched below):
- use proper names like `use_trtllm_prefill_attn` and `use_trtllm_decode_attn`
- remove the processing of the decode case from `use_trtllm_attention`

Maybe it's worth doing in a separate PR.
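A minimal sketch of the suggested split, with hypothetical names and signatures (the real vLLM helpers may differ):

```python
# Hypothetical sketch of the suggested refactor; function names and
# signatures here are illustrative, not the actual vLLM API.

def can_use_trtllm_attention() -> bool:
    # Stand-in for the existing platform/capability check
    # (e.g. FlashInfer available on supported hardware).
    return True

def use_trtllm_prefill_attn(num_prefill_tokens: int) -> bool:
    """Decide whether the TRTLLM-gen prefill kernel should be used."""
    return can_use_trtllm_attention() and num_prefill_tokens > 0

def use_trtllm_decode_attn(num_decode_tokens: int) -> bool:
    """Decide whether the TRTLLM-gen decode kernel should be used."""
    return can_use_trtllm_attention() and num_decode_tokens > 0
```

With one predicate per phase, `use_trtllm_attention` no longer needs to mix prefill and decode logic behind a single name.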
vadiklyutiy
left a comment
One more style note:
Is there some reason to hold `[_]cudagraph_support` and `get_cudagraph_support` in the `*MetaBuilder` classes? Maybe `*Backend(AttentionBackend)` is a better place?
… decoding with FlashInfer (vllm-project#28479) Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: George D. Torres <gdavtor@gmail.com>
… decoding with FlashInfer (vllm-project#28479) Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Purpose
Revised implementation of #26937
This PR makes `_cudagraph_support` a private member and uses `get_cudagraph_support(vllm_config, kv_cache_spec)`. Also updates `_check_and_update_cudagraph_mode` to consider support per-backend, per-KV-group.
TRTLLM-gen kernels support full CUDA graphs, but are only used with FlashInfer on Blackwell under certain conditions.
It might not be safe to change FlashInfer's `cudagraph_support` to `UNIFORM_BATCH` unconditionally, but we can still set it when we know the TRTLLM-gen backend will be used.
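A minimal sketch of the idea, assuming a hypothetical builder class and helper (`FlashInferMetadataBuilderSketch`, `will_use_trtllm_gen_decode`) and a simplified `AttentionCGSupport` enum; the actual vLLM implementation may differ:

```python
from enum import Enum

class AttentionCGSupport(Enum):
    # Simplified subset of the support levels; not the full vLLM enum.
    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2

def will_use_trtllm_gen_decode(vllm_config, kv_cache_spec) -> bool:
    # Hypothetical stand-in for the real check (e.g. FlashInfer selected
    # on Blackwell with a supported configuration).
    return getattr(vllm_config, "is_blackwell", False)

class FlashInferMetadataBuilderSketch:
    # Conservative static default, now held as a private member.
    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    @classmethod
    def get_cudagraph_support(cls, vllm_config, kv_cache_spec):
        # Upgrade to UNIFORM_BATCH only when we know the TRTLLM-gen decode
        # kernel will be used; otherwise keep the conservative default.
        if will_use_trtllm_gen_decode(vllm_config, kv_cache_spec):
            return AttentionCGSupport.UNIFORM_BATCH
        return cls._cudagraph_support
```

This keeps the static class attribute as a safe fallback while letting `_check_and_update_cudagraph_mode` query support per-backend, per-KV-group at runtime.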
Also update the docs to reflect the FlashInfer CUDA graph compatibility, and fill in the missing entry for FlashInferMLA.
FIX #26856
Test Plan
See #26937 for functional correctness testing / benchmarking. Rerunning on this branch gives the same results.
Local test run passes for `tests/v1/attention`.