[Bugfix][Attention][DCP] Set reorder_batch_threshold back to 1 when using FlashMLA and enable DCP#27023
Conversation
Code Review
This pull request addresses a bug that causes an assertion error when using the FlashMLA backend with Decode Context Parallelism (DCP) and a reorder_batch_threshold greater than 1. The fix correctly identifies this unsupported configuration and resets reorder_batch_threshold to 1, along with query_len_support, to prevent the crash. My review includes a suggestion to improve the maintainability of the implementation by using a class attribute for feature detection instead of checking the class name as a string. This will make the code more robust against future refactoring.
```python
if (
    self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
    and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
):
```
Checking the class name as a string (self.__class__.__name__) is fragile and can lead to silent bugs if the class is ever renamed. A more robust and maintainable approach is to use a class attribute to indicate feature support.
You can define a class attribute in MLACommonMetadataBuilder and override it in the specific subclass that supports this feature.
For example:
- Add a new class attribute to `MLACommonMetadataBuilder` (e.g., right after `reorder_batch_threshold`):

```python
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
    ...
    reorder_batch_threshold: int = 1
    _supports_dcp_and_reorder: ClassVar[bool] = False
    ...
```

- In the `FlashAttnMLAMetadataBuilder` class, override this attribute:

```python
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder):
    ...
    _supports_dcp_and_reorder: ClassVar[bool] = True
    ...
```

- Then, update the condition here to use this new attribute, which is more idiomatic and safer.
```diff
-        if (
-            self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
-            and self.__class__.__name__ != "FlashAttnMLAMetadataBuilder"
-        ):
+        if (
+            self.dcp_world_size > 1 and self.reorder_batch_threshold > 1
+            and not self._supports_dcp_and_reorder
+        ):
```
Force-pushed from 8d76fee to 4af3874
MatthewBonanni left a comment
In the following code from gpu_model_runner.py,
vllm/vllm/v1/worker/gpu_model_runner.py, lines 586 to 596 in 00417f4
could we instead just change that assert to set the threshold? The metadata builder's threshold won't be updated, but what ultimately matters is the GPU model runner's threshold, i.e.:
```python
if self.reorder_batch_threshold is not None:
    # NOTE(lucas): currently no backend supports the custom masking
    # required for DCP with q_len > 1, so we assert here. Remove this
    # assert once custom mask support is added to FA3.
    if (
        self.dcp_world_size > 1
        and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
    ):
        logger.warning(
            "This backend does not support DCP with q_len > 1. "
            "Setting reorder_batch_threshold to 1."
        )
        self.reorder_batch_threshold = 1
```
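The downgrade-instead-of-assert idea can be sketched as a standalone snippet. The `RunnerSketch` class and its fields below are hypothetical stand-ins for the model runner's state; vLLM's actual `envs` module and logger are not used.

```python
import logging

logger = logging.getLogger("dcp_sketch")


class RunnerSketch:
    """Hypothetical stand-in for the GPU model runner's relevant state."""

    def __init__(self, dcp_world_size: int, attention_backend: str,
                 reorder_batch_threshold: int) -> None:
        self.dcp_world_size = dcp_world_size
        self.attention_backend = attention_backend
        self.reorder_batch_threshold = reorder_batch_threshold

    def clamp_threshold_for_dcp(self) -> int:
        # Downgrade instead of asserting: only FLASH_ATTN_MLA supports
        # DCP with q_len > 1, so every other backend falls back to 1.
        if self.reorder_batch_threshold is not None:
            if (
                self.dcp_world_size > 1
                and self.attention_backend != "FLASH_ATTN_MLA"
            ):
                logger.warning(
                    "This backend does not support DCP with q_len > 1. "
                    "Setting reorder_batch_threshold to 1."
                )
                self.reorder_batch_threshold = 1
        return self.reorder_batch_threshold


print(RunnerSketch(2, "FLASHMLA", 512).clamp_threshold_for_dcp())        # 1
print(RunnerSketch(2, "FLASH_ATTN_MLA", 512).clamp_threshold_for_dcp())  # 512
```

The trade-off the comment raises is visible here: the metadata builder's own copy of the threshold is never touched, only the runner-side value that drives batch reordering.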
… enable DCP
Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
Force-pushed from 4af3874 to a2d5ef0
Fixed by #28100
Purpose
For the FlashMLA backend, #26541 set the default value of `reorder_batch_threshold` to 512.

vllm/vllm/v1/attention/backends/mla/flashmla.py, lines 71 to 76 in 00417f4
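The referenced FlashMLA lines are not reproduced in this page. As a hedged illustration of what the description says they do (the class names mirror the discussion and the value 512 comes from this PR description; this is not the actual vLLM source):

```python
class MLACommonMetadataBuilder:
    # Common default shared by MLA backends in this sketch.
    reorder_batch_threshold: int = 1


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Illustrative override: #26541 raised FlashMLA's default to 512
    # to let more requests take the decode path.
    reorder_batch_threshold: int = 512


print(FlashMLAMetadataBuilder.reorder_batch_threshold)  # 512
```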
However, DCP supports `reorder_batch_threshold` > 1 only when the FlashAttnMLA backend is used (#25049). Therefore, the following assertion error occurs when using the FlashMLA backend.

vllm/vllm/v1/worker/gpu_model_runner.py, lines 586 to 596 in 00417f4
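The failure mode can be sketched as follows; the guard's exact shape and message are assumed from this discussion, not copied from gpu_model_runner.py.

```python
def check_dcp_config(dcp_world_size: int, backend: str, threshold: int) -> None:
    """Hypothetical version of the runner-side guard described above."""
    if threshold is not None and dcp_world_size > 1 and threshold > 1:
        # Only FlashAttnMLA supports q_len > 1 under DCP (#25049).
        assert backend == "FLASH_ATTN_MLA", (
            "DCP does not support reorder_batch_threshold > 1 "
            f"with backend {backend}"
        )


check_dcp_config(2, "FLASH_ATTN_MLA", 512)  # passes
try:
    # FlashMLA + DCP + the raised default of 512 trips the guard,
    # matching the crash this PR fixes.
    check_dcp_config(2, "FLASHMLA", 512)
except AssertionError as e:
    print("AssertionError:", e)
```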
This PR temporarily fixes the issue by setting `reorder_batch_threshold` back to 1. Looking forward to DCP supporting `reorder_batch_threshold` > 1 with FlashMLA in the future :).

Test Plan
Test Result
main
this PR
INFO: 127.0.0.1:47140 - "POST /v1/chat/completions HTTP/1.1" 200 OK

cc @minosfuture @MatthewBonanni @youkaichao @LucasWilkinson