-
-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 #25049
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8cdbd20
45bb7e8
9c0176b
4cc05c6
a6efa96
ed6dcdd
efaf7fa
036357e
19a7f8c
3a5dcb2
56478b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,6 @@ | |
| get_flash_attn_version, | ||
| ) | ||
| from vllm.config import VllmConfig | ||
| from vllm.distributed.parallel_state import get_dcp_group | ||
| from vllm.logger import init_logger | ||
| from vllm.v1.attention.backends.mla.common import ( | ||
| MLACommonBackend, | ||
|
|
@@ -107,12 +106,6 @@ def __init__( | |
| # pre-allocated during capture. | ||
| self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH | ||
|
|
||
| # TODO(lucas): Until we add support for the DCP custom masking we need | ||
| # to restrict decodes to q_len == 1 when DCP is enabled. | ||
| self.reorder_batch_threshold = ( | ||
| 1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold | ||
| ) | ||
|
|
||
| def _schedule_decode( | ||
| self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal | ||
| ): | ||
|
|
@@ -121,7 +114,7 @@ def _schedule_decode( | |
| batch_size=num_reqs, | ||
| max_seqlen_q=max_query_len, | ||
| max_seqlen_k=max_seq_len, | ||
| num_heads_q=self.num_heads, | ||
| num_heads_q=self.num_heads * self.dcp_world_size, | ||
| num_heads_kv=1, | ||
| headdim=self.mla_dims.qk_rope_head_dim, | ||
| cache_seqlens=seqlens, | ||
|
|
@@ -142,10 +135,11 @@ def _build_decode( | |
| query_start_loc_cpu: torch.Tensor, | ||
| query_start_loc_device: torch.Tensor, | ||
| num_decode_tokens: int, | ||
| dcp_tot_seq_lens_device: Optional[torch.Tensor], | ||
| ) -> FlashAttnMLADecodeMetadata: | ||
| query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] | ||
| max_query_len = query_lens_cpu.max().item() | ||
| max_seq_len = seq_lens_cpu.max().item() | ||
| max_seq_len = seq_lens_device.max().item() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @minosfuture this introduces a CPU sync and will be problematic for performance, especially in the async scheduling case. It looks like it should not be too hard to ensure we obtain this from the cpu-side tensor in the DCP case too? |
||
|
|
||
| scheduler_metadata = self._schedule_decode( | ||
| num_reqs=seq_lens_cpu.numel(), | ||
|
|
@@ -188,6 +182,7 @@ def _build_decode( | |
| max_seq_len=max_seq_len, | ||
| scheduler_metadata=scheduler_metadata, | ||
| max_num_splits=max_num_splits, | ||
| dcp_tot_seq_lens=dcp_tot_seq_lens_device, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -289,6 +284,9 @@ def _forward_decode( | |
| fa_version=3, # only version 3 is supported | ||
| scheduler_metadata=attn_metadata.decode.scheduler_metadata, | ||
| num_splits=attn_metadata.decode.max_num_splits, | ||
| cp_world_size=self.dcp_world_size, | ||
| cp_rank=self.dcp_rank, | ||
minosfuture marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, | ||
| ) | ||
|
|
||
| if self.need_to_return_lse_for_decode: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -398,6 +398,10 @@ def __init__( | |
| self.max_num_reqs + 1, dtype=torch.int32 | ||
| ) | ||
| self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) | ||
| if self.dcp_world_size > 1: | ||
| self.dcp_local_seq_lens = self._make_buffer( | ||
| self.max_num_reqs, dtype=torch.int32 | ||
| ) | ||
| # Because inputs_embeds may be bfloat16 and we don't need a numpy | ||
| # version of this tensor, avoid a RuntimeError by not creating a | ||
| # numpy buffer. | ||
|
|
@@ -581,7 +585,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: | |
| # NOTE(lucas): currently no backend supports the custom masking | ||
| # required for DCP with q_len > 1, so we assert here. Remove this | ||
| # assert once the custom mask support is added to FA3. | ||
| if self.dcp_world_size > 1: | ||
| if ( | ||
| self.dcp_world_size > 1 | ||
| and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA" | ||
| ): | ||
| assert self.reorder_batch_threshold == 1, ( | ||
| "DCP not support reorder_batch_threshold > 1 now." | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since only flash_attn_mla supports the custom mask, we can't just remove this assert right now?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make sense. I'll make a whitelist here for FA3 MLA |
||
| ) | ||
|
|
@@ -1335,6 +1342,9 @@ def _prepare_inputs( | |
| num_logits_indices=logits_indices.size(0), | ||
| causal=True, | ||
| encoder_seq_lens=encoder_seq_lens, | ||
| dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] | ||
| if self.dcp_world_size > 1 | ||
| else None, | ||
| ) | ||
|
|
||
| if self.speculative_config and spec_decode_common_attn_metadata is None: | ||
|
|
@@ -3309,6 +3319,9 @@ def _dummy_run( | |
| kv_cache_group_id | ||
| ].slot_mapping.gpu[:num_tokens], | ||
| causal=True, | ||
| dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] | ||
| if self.dcp_world_size > 1 | ||
| else None, | ||
| ) | ||
| for attn_group in self.attn_groups[kv_cache_group_id]: | ||
| if ubatch_slices is not None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.