[Models]: Use MMEncoderAttention for MoonViT #31738

Isotr0py merged 9 commits into vllm-project:main

Conversation
Code Review
This pull request refactors the MoonViT model to use the standardized MMEncoderAttention interface, which is a good step towards improving code maintainability. It also adds support for data parallelism in the ViT encoder. However, I've found a critical issue where a signature change in MoonVitPretrainedModel breaks its usage in the KimiVL model, which will cause a runtime error. Additionally, there's another issue where the multimodal_config is not passed to MMEncoderAttention, preventing users from configuring the attention backend. Both issues need to be addressed.
vllm/model_executor/models/moonvit.py (518-525)
The signature of MoonVitPretrainedModel.__init__ has been changed to accept a multimodal_config object instead of a use_data_parallel boolean. However, the call site in vllm/model_executor/models/kimi_vl.py has not been updated accordingly and still passes a boolean. This will lead to an AttributeError at runtime when multimodal_config.mm_encoder_tp_mode is accessed within MoonVitEncoderLayer. To fix this, kimi_vl.py must be updated to pass the MultiModalConfig object.
vllm/model_executor/models/moonvit.py (362-366)
The multimodal_config parameter is available in MoonVitEncoderLayer.__init__ but is not passed to the MMEncoderAttention constructor. This prevents users from overriding the vision encoder's attention backend via the mm_encoder_attn_backend setting in MultiModalConfig. The multimodal_config should be forwarded to MMEncoderAttention to ensure correct backend selection and configurability.
self.attn = MMEncoderAttention(
num_heads=self.num_heads,
head_size=self.hidden_size_per_attention_head,
prefix=f"{prefix}.attn",
multimodal_config=multimodal_config,
)
/gemini review
Code Review
This pull request refactors the MoonViT model to use the centralized MMEncoderAttention layer and tensor-parallel linear layers (ColumnParallelLinear, RowParallelLinear, QKVParallelLinear). This is a great improvement for code consistency and maintainability within the vLLM project. The changes also introduce data parallelism support for the Vision Transformer (ViT) part of the model. The implementation looks solid, but I've found one minor issue regarding explicitness in the code that should be addressed.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: h100 <h100@inferact.ai> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: h100 <h100@inferact.ai>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: h100 <h100@inferact.ai> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: h100 <h100@inferact.ai> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
- Replace custom multihead_attention/eager_attention with MMEncoderAttention
- Add tensor parallel support using QKVParallelLinear and RowParallelLinear
- MLP2 now uses ColumnParallelLinear/RowParallelLinear with TP/DP support
- Pass multimodal_config through MoonViT3dPretrainedModel to encoder layers
- Rename vision_tower_forward_auto to vision_tower_forward

This aligns with PR #31738 pattern for unified multi-platform attention backends.

Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn>
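The tensor-parallel layers named in this commit all rest on the same idea: attention heads (and the corresponding QKV projection rows) are split evenly across TP ranks, so each rank only materializes its slice of the weights. A toy sketch of that partitioning — no vLLM dependency, names are illustrative:

```python
def shard_heads(num_heads: int, tp_size: int) -> list[int]:
    """Evenly divide attention heads across tp_size ranks,
    as a QKVParallelLinear-style layer does with its weight rows."""
    assert num_heads % tp_size == 0, "head count must divide tp_size evenly"
    return [num_heads // tp_size] * tp_size


# With tp=2 (as in the test plan below), 16 heads give 8 per rank;
# each rank's QKV projection covers only its own 8-head slice.
print(shard_heads(16, 2))  # [8, 8]
```

RowParallelLinear then reduces the per-rank partial outputs back together, which is why the column-parallel/row-parallel pairing appears in both the attention and MLP2 changes.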
Purpose
Test Plan
Test Result
tp=2, mm_encoder_tp_mode="weight"
tp=2, mm_encoder_tp_mode="data"
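The two test configurations above exercise the two ways the ViT encoder can be parallelized at tp=2. A toy dispatch capturing the distinction (function and return strings are illustrative, not vLLM code):

```python
def vit_parallel_strategy(tp_size: int, mm_encoder_tp_mode: str) -> str:
    """Toy sketch of the two mm_encoder_tp_mode settings tested above.

    "weight": shard the encoder's QKV/MLP weights across TP ranks.
    "data":   replicate encoder weights on every rank and split the
              image batch across ranks instead.
    """
    if tp_size == 1:
        return "single rank, no parallelism"
    if mm_encoder_tp_mode == "data":
        return "replicate weights, split image batch across ranks"
    return "shard QKV/MLP weights across ranks"


print(vit_parallel_strategy(2, "weight"))
print(vit_parallel_strategy(2, "data"))
```

Data parallelism trades extra weight memory for avoiding per-layer communication in the (comparatively small) vision encoder, which is why the PR adds it alongside the tensor-parallel path.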
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.