Add fast decode plan for flashinfer mla #3987
Merged
zhyncs merged 6 commits into sgl-project:main Mar 3, 2025
Conversation
zhyncs
approved these changes
Mar 3, 2025
merrymercy
reviewed
Mar 3, 2025
    forward_mode: ForwardMode,
    encoder_lens: Optional[torch.Tensor],
    spec_info: Optional[SpecInfo],
    **kwargs,
Contributor
Do not use **kwargs. It makes the code harder to read because we do not know what the exact arguments are. Is it possible to specify them more explicitly?
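A minimal sketch of this suggestion: name the accepted parameters explicitly instead of hiding them in **kwargs. The exact signature here is hypothetical, not the real backend API, but it shows the benefit: misspelled keyword arguments now fail loudly instead of being silently swallowed.

```python
def init_forward_metadata_capture_cuda_graph(
    bs,
    seq_lens,
    encoder_lens=None,
    spec_info=None,
    seq_lens_cpu=None,  # named explicitly instead of hidden in **kwargs
):
    """Callers see the exact contract; unknown keywords raise TypeError."""
    return {"bs": bs, "has_cpu_lens": seq_lens_cpu is not None}
```

With **kwargs, a typo like `seq_len_cpu=` would be accepted and ignored; with the explicit signature it raises a `TypeError` at the call site.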
    def get_model_worker_batch(self):
        if self.forward_mode.is_decode_or_idle():
            decode_seq_lens = self.seq_lens.cpu()
Contributor
This will slow down other things (e.g., speculative decoding, where the overlap scheduler is turned off). Can we do this only when it is needed?
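The concern above is that `.cpu()` forces a blocking device-to-host copy on every batch, even for backends that never read the result. A CPU-only sketch of the requested fix (plain lists stand in for tensors, and the `needs_cpu_seq_lens` flag is hypothetical):

```python
class Batch:
    """Toy model of a scheduler batch; counts simulated D2H syncs."""

    def __init__(self, seq_lens, is_decode, needs_cpu_seq_lens):
        self.seq_lens = seq_lens              # stands in for a GPU tensor
        self.is_decode = is_decode
        self.needs_cpu_seq_lens = needs_cpu_seq_lens
        self.sync_count = 0                   # counts simulated D2H copies

    def _to_cpu(self, t):
        self.sync_count += 1                  # models a blocking .cpu() sync
        return list(t)

    def get_model_worker_batch(self):
        decode_seq_lens = None
        # Only pay the device-to-host copy when a backend will read it.
        if self.is_decode and self.needs_cpu_seq_lens:
            decode_seq_lens = self._to_cpu(self.seq_lens)
        return decode_seq_lens
```

Backends that do not need host-side seq lens (e.g. speculative decoding paths) then incur zero extra synchronizations.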
    self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
    self.positions[:raw_num_token].copy_(forward_batch.positions)
    if forward_batch.decode_seq_lens_cpu is not None:
        self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
Contributor
If it is a CPU tensor, it does not need to go through these CUDA graph copies.
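The point of this comment: buffers that a CUDA graph replays must live at fixed device addresses and be updated in place with `copy_`, but a CPU tensor is never captured by the graph, so a plain reference swap suffices. A CPU-only sketch with lists standing in for tensors (the class and names are illustrative, not the real runner):

```python
class GraphRunner:
    """Toy CUDA-graph runner: GPU buffers are in-place, CPU metadata is not."""

    def __init__(self, max_bs):
        # GPU-side buffers captured by the graph: must keep stable addresses.
        self.out_cache_loc = [0] * max_bs
        self.positions = [0] * max_bs
        # CPU-side metadata: not part of the captured graph.
        self.seq_lens_cpu = None

    def prepare_replay(self, out_cache_loc, positions, seq_lens_cpu):
        n = len(out_cache_loc)
        # In-place slice assignment models copy_(): same buffer, new values.
        self.out_cache_loc[:n] = out_cache_loc
        self.positions[:n] = positions
        # CPU tensor: a simple reassignment is enough; no copy_ needed.
        self.seq_lens_cpu = seq_lens_cpu
```

The in-place updates keep the captured pointers valid across replays, while the CPU field can change identity freely between replays.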
merrymercy added a commit that referenced this pull request Mar 3, 2025
This reverts commit fa56106.
aoshen524 pushed a commit to aoshen524/sglang that referenced this pull request Mar 10, 2025
Motivation
When the flashinfer MLA backend and CUDA graph are used together, graph replay hangs due to the transmission of indptr tensors between CPU and GPU in BatchMLAPagedAttentionWrapper.plan. This PR fixes the issue by adding a new decode_seq_lens_cpu field to the forward batch and customizing a faster decode plan for graph replay. Also, some issues (#3906, #3917) point out that the current flashinfer MLA backend behaves worse than Triton in long-output cases. Hopefully this PR will fix that problem as well.
Modifications
- Add decode_seq_lens_cpu to the forward batch, which puts the seq_lens information on the CPU in advance.
- Add fast_mla_decode_plan, which avoids transmitting indptr tensors from GPU to CPU during graph replay.

Accuracy
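The core idea of the two modifications can be illustrated with a CPU-only sketch (plain lists stand in for tensors; none of these function names are the actual flashinfer API). The slow path copies indptr from device to host inside the plan call, a blocking sync that stalls CUDA-graph replay, while the fast path rebuilds indptr with a host-side prefix sum over seq lens that were already staged on the CPU:

```python
def slow_plan(indptr_gpu, to_cpu):
    # Blocking device-to-host copy during graph replay -> potential hang.
    return to_cpu(indptr_gpu)


def fast_plan(seq_lens_cpu):
    # indptr is a prefix sum of sequence lengths; with seq lens already on
    # the host, it can be rebuilt with no GPU synchronization at all.
    indptr = [0]
    for n in seq_lens_cpu:
        indptr.append(indptr[-1] + n)
    return indptr
```

For example, sequences of lengths 3, 5, and 2 yield the page-offset array [0, 3, 8, 10] without ever touching the device.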
Launching
GSM8K
MMLU
Benchmark
To better demonstrate the improvement from this PR, the benchmarks are run on long-output workloads (so that the number of graph replays is increased) with DeepSeek-V2-Lite. The machine is an NVIDIA H200. Each benchmark is run five times and the average throughput is computed. After this PR, the throughput of flashinfer MLA on these workloads improves by 1% to 2%.
To Launch:
Input-4096-Output-2048 (same workload as #3917)
Input-180-Output-400 (same workload as #3906)
Input-100-Output-2000
Profiler Result
Profiling with the torch profiler shows that the time spent waiting for memcpyAsync is removed. Since MLA with weight absorption is compute-bound and the GPU is fully utilized, the influence on end-to-end throughput is modest.
Before this PR: (profiler trace screenshot)
After this PR: (profiler trace screenshot)
Checklist