[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 #25049
vllm-bot merged 11 commits into vllm-project:main
Conversation
Code Review
This pull request aims to enable multi-token prediction (MTP) with decode context parallelism (DCP) for FlashAttention-3. The changes involve removing a restriction on query length for DCP and passing cp_world_size and cp_rank to the attention kernel. While the changes in vllm/v1/worker/gpu_model_runner.py are correct, there is a critical issue in vllm/v1/attention/backends/mla/flashattn_mla.py. The newly used attributes self.dcp_world_size and self.dcp_rank are not properly initialized due to an issue in the MLACommonImpl base class, which will cause a TypeError at runtime. This must be addressed for the feature to function correctly.
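The initialization gap the review describes could be closed along these lines. This is a minimal sketch, not the actual vLLM code: `get_dcp_group` is a hypothetical stand-in for vLLM's parallel-state accessor, and the fallback values are assumptions based on the review text.

```python
from dataclasses import dataclass


@dataclass
class _DcpGroup:
    """Stand-in for a DCP process-group handle (hypothetical)."""
    world_size: int
    rank_in_group: int


def get_dcp_group() -> _DcpGroup:
    # Hypothetical accessor; raises when DCP was never initialized,
    # mirroring how parallel-state getters typically behave.
    raise AssertionError("DCP group not initialized")


class FlashAttnMLAImplSketch:
    """Sketch: set DCP attributes defensively so a later comparison like
    `self.dcp_world_size > 1` never hits None and raises TypeError."""

    def __init__(self) -> None:
        try:
            group = get_dcp_group()
            self.dcp_world_size = group.world_size
            self.dcp_rank = group.rank_in_group
        except AssertionError:
            # DCP disabled: fall back to single-rank defaults.
            self.dcp_world_size = 1
            self.dcp_rank = 0


impl = FlashAttnMLAImplSketch()
assert impl.dcp_world_size == 1 and impl.dcp_rank == 0
```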
Thanks for this contribution! Just wanted to leave a reminder to update the FlashAttention
```python
# assert once custom mask support is added to FA3.
if self.dcp_world_size > 1:
    assert self.reorder_batch_threshold == 1, \
        "DCP does not support reorder_batch_threshold > 1 yet."
```
Since only flash_attn_mla supports the custom mask, we can't just remove this assert right now, can we?
Makes sense. I'll add a whitelist here for FA3 MLA.
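A whitelist along those lines might look like the sketch below. The backend names, set name, and function signature are illustrative assumptions, not the merged code:

```python
# Backends whose kernels accept the custom mask needed for
# query length > 1 under DCP (illustrative names).
DCP_MULTI_QUERY_WHITELIST = {"FLASH_ATTN_MLA"}


def check_dcp_support(backend_name: str, dcp_world_size: int,
                      reorder_batch_threshold: int) -> None:
    """Restrict non-whitelisted backends to reorder_batch_threshold == 1
    when DCP is active (hypothetical helper)."""
    if dcp_world_size > 1 and backend_name not in DCP_MULTI_QUERY_WHITELIST:
        assert reorder_batch_threshold == 1, (
            f"{backend_name} does not support reorder_batch_threshold > 1 "
            "with DCP yet.")


# Whitelisted backend may use a larger threshold (e.g. for MTP):
check_dcp_support("FLASH_ATTN_MLA", dcp_world_size=4, reorder_batch_threshold=2)
# Non-whitelisted backends remain restricted to threshold == 1:
check_dcp_support("TRITON_MLA", dcp_world_size=4, reorder_batch_threshold=1)
```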
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: Ming Yang <minos.future@gmail.com>
vllm/v1/attention/backends/utils.py
Outdated
```python
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None

cp_seq_lens: Optional[torch.Tensor] = None
```
Sounds good. Let me keep the dcp prefix.
LucasWilkinson
left a comment
overall looks good to me; left one nit
```diff
 query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
 max_query_len = query_lens_cpu.max().item()
-max_seq_len = seq_lens_cpu.max().item()
+max_seq_len = seq_lens_device.max().item()
```
@minosfuture this introduces a CPU sync and will be problematic for performance, especially in the async scheduling case. It looks like it should not be too hard to ensure we obtain this from the cpu-side tensor in the DCP case too?
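For context: calling `.item()` (or `.max().item()`) on a CUDA tensor forces a device-to-host synchronization, while reading the max from a tensor that already lives on the CPU does not. A minimal sketch of the two paths, with illustrative names and CPU tensors standing in for device ones:

```python
import torch

seq_lens_cpu = torch.tensor([3, 7, 5])      # host-side copy kept by the builder
seq_lens_device = seq_lens_cpu.clone()      # stands in for the CUDA tensor

# Host-side max: no device sync, safe for async scheduling.
max_seq_len = int(seq_lens_cpu.max().item())

# On a real CUDA tensor, this .item() would block the host until the
# GPU catches up -- the sync the reviewer is flagging.
max_seq_len_synced = int(seq_lens_device.max().item())

assert max_seq_len == max_seq_len_synced == 7
```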
@minosfuture does this support the TP+DP+MTP combination?
Purpose
Combined with vllm-project/flash-attention#93, this enables MTP (multi-token prediction) with DCP (decode context parallelism). It also allows prefill and decode to be mixed in a batch.
See vllm-project/flash-attention#93 for the implementation and solution details. Here we just need to pass the cp world size and cp rank.
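To see why the kernel needs the CP world size and CP rank: under DCP each rank holds only a slice of the KV cache, and with query length > 1 (MTP) each query token sees a different causal prefix, so the kernel must know which global positions its local KV slice covers. The sketch below assumes an interleaved (round-robin) KV layout across DCP ranks; the layout and function names are illustrative, not the kernel implementation:

```python
def local_kv_positions(seq_len: int, cp_world_size: int,
                       cp_rank: int) -> list[int]:
    # Interleaved layout assumed: token t lives on rank t % cp_world_size.
    return [t for t in range(seq_len) if t % cp_world_size == cp_rank]


def causal_keep_mask(q_pos: int, kv_positions: list[int]) -> list[bool]:
    # A query token may only attend to KV positions at or before it,
    # which is why a per-rank custom mask is needed once q_len > 1.
    return [kv <= q_pos for kv in kv_positions]


# Example: seq_len 10, DCP world size 4, two speculative (MTP) query
# tokens at global positions 8 and 9, viewed from rank 1.
kv = local_kv_positions(10, cp_world_size=4, cp_rank=1)   # [1, 5, 9]
assert causal_keep_mask(8, kv) == [True, True, False]
assert causal_keep_mask(9, kv) == [True, True, True]
```

Note how the two query tokens get different masks over the same local KV slice; with query length fixed at 1 a single mask per rank sufficed, which is what the old restriction enforced.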
Test Plan
Test Result
Benchmark
#### Metric Details
- With MTP and TP8, DCP4
- With MTP and TP8, DCP8
- With MTP and TP8
- With TP8, DCP8
- With TP8, DCP4
- With TP8
LM Eval
With TP8, DCP8 + MTP:
local-completions (model=deepseek-ai/DeepSeek-R1-0528,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.