[Attention] Support MTP with DCP #24997
MatthewBonanni wants to merge 2 commits into vllm-project:main from
Conversation
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Code Review
This pull request aims to enable Multi-Token Prediction (MTP) with decode context parallelism (DCP) by supporting query_len > 1 in the FlashAttention MLA backend. The changes remove the previous restrictions and add metadata for a custom causal mask.
I've found a critical issue where the new logic to compute query_base_positions in MLACommonMetadataBuilder is not being used by FlashAttnMLAMetadataBuilder because it overrides the _build_decode method. This will prevent the feature from working as intended. Please see my detailed comment.
```python
# Compute DCP query base positions if using DCP
query_base_positions = None
if self.dcp_world_size > 1:
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    query_base_positions = (seq_lens_cpu - query_lens).to(
        seq_lens_device.device)

return MLACommonDecodeMetadata(
    block_table=block_table_tensor,
    seq_lens=seq_lens_device,
    query_base_positions=query_base_positions,
```
This change correctly computes query_base_positions for DCP. However, FlashAttnMLAMetadataBuilder in vllm/v1/attention/backends/mla/flashattn_mla.py overrides _build_decode and does not call this base implementation. As a result, query_base_positions will be None for the FlashAttention MLA backend, and the MTP with context parallelism feature will not work correctly.
To fix this, move the computation into FlashAttnMLAMetadataBuilder._build_decode, or refactor so that FlashAttnMLAMetadataBuilder can reuse the base implementation and pass query_base_positions to the FlashAttnMLADecodeMetadata constructor; a sketch of this approach follows below.
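Here is a minimal sketch of the suggested refactor. The helper name compute_query_base_positions and the _build_decode call site are assumptions for illustration; only the class names and the base-class computation shown in the diff above come from the PR itself.

```python
from typing import Optional

import torch


def compute_query_base_positions(
    query_start_loc_cpu: torch.Tensor,  # prefix sums of query lens, [num_reqs + 1]
    seq_lens_cpu: torch.Tensor,         # total sequence length per request, [num_reqs]
    device: torch.device,
    dcp_world_size: int,
) -> Optional[torch.Tensor]:
    # Same computation as the base-class diff above: the first query token
    # of each request sits at global position seq_len - query_len.
    if dcp_world_size <= 1:
        return None
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    return (seq_lens_cpu - query_lens).to(device)


# Hypothetical call site inside FlashAttnMLAMetadataBuilder._build_decode
# (the real signature in flashattn_mla.py may differ):
#
#     query_base_positions = compute_query_base_positions(
#         query_start_loc_cpu, seq_lens_cpu, seq_lens_device.device,
#         self.dcp_world_size)
#     return FlashAttnMLADecodeMetadata(
#         block_table=block_table_tensor,
#         seq_lens=seq_lens_device,
#         query_base_positions=query_base_positions,
#     )
```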
This pull request has merge conflicts that must be resolved before it can be merged.
superseded by #25049
Purpose
#24453 added DCP support but did not support query_len > 1. This PR, which depends on a corresponding FlashAttention PR (vllm-project/flash-attention#92), implements a custom causal mask so the FlashAttention MLA backend can handle query_len > 1 under DCP, thereby enabling MTP; a toy sketch of the masking constraint appears below.
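To make the mask concrete, here is a toy sketch (not the PR's kernel code) of the causal constraint each DCP rank must enforce when query_len > 1. It assumes KV tokens are distributed round-robin across DCP ranks, so local slot i on rank r holds global position i * dcp_world_size + r; the actual layout may differ.

```python
import torch


def dcp_causal_mask(query_base_position: int, query_len: int,
                    num_local_kv: int, dcp_rank: int,
                    dcp_world_size: int) -> torch.Tensor:
    # Global positions of this request's query tokens, shape [query_len, 1].
    q_pos = query_base_position + torch.arange(query_len).unsqueeze(1)
    # Global positions of the KV tokens held by this rank, assuming a
    # round-robin (interleaved) layout, shape [num_local_kv].
    kv_pos = torch.arange(num_local_kv) * dcp_world_size + dcp_rank
    # Causal rule: a query token may attend to any KV token at a global
    # position no greater than its own. Shape [query_len, num_local_kv].
    return kv_pos.unsqueeze(0) <= q_pos
```

For example, with seq_len = 7 and query_len = 2 (so query_base_position = 5), dcp_world_size = 2, and rank 0 holding global positions 0, 2, 4, 6, the first query token (position 5) may attend to positions 0, 2, and 4, while the second (position 6) may also attend to 6. With query_len == 1 this degenerates to the unmasked case that #24453 already handled.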
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.