
[Models]: Use MMEncoderAttention for MoonViT#31738

Merged
Isotr0py merged 9 commits into vllm-project:main from Isotr0py:moonvit-mha
Jan 6, 2026

Conversation

Member

@Isotr0py Isotr0py commented Jan 5, 2026

Purpose

Test Plan

python examples/offline_inference/vision_language.py -m kimi_vl

Test Result

tp=2, mm_encoder_tp_mode="weight"
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:50 [gpu_model_runner.py:3758] Starting to load model /home/mozf/LLM/Kimi-VL-A3B-Thinking-2506/...
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:50 [mm_encoder_attention.py:83] Using AttentionBackendEnum.FLASH_ATTN for MMEncoderAttention.
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) INFO 01-06 00:29:50 [mm_encoder_attention.py:83] Using AttentionBackendEnum.FLASH_ATTN for MMEncoderAttention.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:51 [cuda.py:351] Using TRITON_MLA attention backend out of potential backends: ('TRITON_MLA',)
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:00<00:05,  1.01it/s]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:02<00:05,  1.14s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:03<00:04,  1.14s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:04<00:03,  1.15s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:05<00:02,  1.16s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:06<00:01,  1.00s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:07<00:00,  1.04s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:07<00:00,  1.07s/it]
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) 
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:58 [default_loader.py:308] Loading weights took 7.64 seconds
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:59 [gpu_model_runner.py:3855] Model loading took 15.5315 GiB memory and 7.926830 seconds
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:29:59 [gpu_model_runner.py:4665] Encoder cache will be initialized with a budget of 8192 tokens, and profiled with 5 image items of the maximum feature size.
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) INFO 01-06 00:29:59 [gpu_model_runner.py:4665] Encoder cache will be initialized with a budget of 8192 tokens, and profiled with 5 image items of the maximum feature size.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:06 [backends.py:644] Using cache directory: /home/mozf/.cache/vllm/torch_compile_cache/31e1d32d9a/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:06 [backends.py:704] Dynamo bytecode transform time: 5.38 s
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:09 [backends.py:261] Cache the graph of compile range (1, 8192) for later use
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) INFO 01-06 00:30:09 [backends.py:261] Cache the graph of compile range (1, 8192) for later use
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) WARNING 01-06 00:30:09 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) WARNING 01-06 00:30:09 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) INFO 01-06 00:30:09 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:09 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) WARNING 01-06 00:30:10 [fused_moe.py:1054] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/mozf/develop-projects/vllm/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=NVIDIA_GeForce_RTX_3090.json
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:12 [backends.py:278] Compiling a graph for compile range (1, 8192) takes 3.03 s
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:12 [monitor.py:34] torch.compile takes 8.41 s in total
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:13 [gpu_worker.py:361] Available KV cache memory: 3.59 GiB
(EngineCore_DP0 pid=1363147) INFO 01-06 00:30:14 [kv_cache_utils.py:1305] GPU KV cache size: 124,096 tokens
(EngineCore_DP0 pid=1363147) INFO 01-06 00:30:14 [kv_cache_utils.py:1310] Maximum concurrency for 4,096 tokens per request: 30.30x
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) WARNING 01-06 00:30:14 [gpu_model_runner.py:5075] CUDAGraphMode.FULL_AND_PIECEWISE is not supported with TritonMLABackend backend (support: AttentionCGSupport.NEVER); setting cudagraph_mode=PIECEWISE because attention is compiled piecewise
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) WARNING 01-06 00:30:14 [gpu_model_runner.py:5075] CUDAGraphMode.FULL_AND_PIECEWISE is not supported with TritonMLABackend backend (support: AttentionCGSupport.NEVER); setting cudagraph_mode=PIECEWISE because attention is compiled piecewise
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|                                                                                                                 | 0/51 [00:00<?, ?it/s](EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) WARNING 01-06 00:30:14 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:14 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) WARNING 01-06 00:30:14 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1363147) (Worker_TP1 pid=1363157) INFO 01-06 00:30:14 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 51/51 [00:09<00:00,  5.51it/s]
(EngineCore_DP0 pid=1363147) (Worker_TP0 pid=1363153) INFO 01-06 00:30:24 [gpu_model_runner.py:4806] Graph capturing finished in 10 secs, took 1.11 GiB
(EngineCore_DP0 pid=1363147) INFO 01-06 00:30:24 [core.py:273] init engine (profile, create kv cache, warmup model) took 24.86 seconds
...
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.15it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.49s/it, est. speed input: 1854.94 toks/s, output: 42.84 toks/s]
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall, modern tower with a spherical observation deck and a lattice structure below. The sky is clear and blue. So the content includes cherry
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. Behind them, the Tokyo Skytree is visible against a clear blue sky. So the content includes cherry blossoms, the Tokyo Skytree, and the blue sky
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall tower, which is the Tokyo Skytree, a well-known landmark in Tokyo, Japan. The sky is clear and blue,
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall tower with a spherical observation deck and a lattice structure below. The sky is clear and blue. So the content includes cherry blossoms,
--------------------------------------------------
tp=2, mm_encoder_tp_mode="data"
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:44 [gpu_model_runner.py:3758] Starting to load model /home/mozf/LLM/Kimi-VL-A3B-Thinking-2506/...
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:45 [mm_encoder_attention.py:83] Using AttentionBackendEnum.FLASH_ATTN for MMEncoderAttention.
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) INFO 01-06 00:31:45 [mm_encoder_attention.py:83] Using AttentionBackendEnum.FLASH_ATTN for MMEncoderAttention.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:45 [cuda.py:351] Using TRITON_MLA attention backend out of potential backends: ('TRITON_MLA',)
Loading safetensors checkpoint shards:   0% Completed | 0/7 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  14% Completed | 1/7 [00:01<00:06,  1.07s/it]
Loading safetensors checkpoint shards:  29% Completed | 2/7 [00:02<00:06,  1.20s/it]
Loading safetensors checkpoint shards:  43% Completed | 3/7 [00:03<00:04,  1.23s/it]
Loading safetensors checkpoint shards:  57% Completed | 4/7 [00:04<00:03,  1.21s/it]
Loading safetensors checkpoint shards:  71% Completed | 5/7 [00:06<00:02,  1.24s/it]
Loading safetensors checkpoint shards:  86% Completed | 6/7 [00:06<00:01,  1.07s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:07<00:00,  1.09s/it]
Loading safetensors checkpoint shards: 100% Completed | 7/7 [00:07<00:00,  1.14s/it]
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) 
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:53 [default_loader.py:308] Loading weights took 8.13 seconds
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:54 [gpu_model_runner.py:3855] Model loading took 15.9141 GiB memory and 8.405087 seconds
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) INFO 01-06 00:31:54 [gpu_model_runner.py:4665] Encoder cache will be initialized with a budget of 8192 tokens, and profiled with 5 image items of the maximum feature size.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:31:54 [gpu_model_runner.py:4665] Encoder cache will be initialized with a budget of 8192 tokens, and profiled with 5 image items of the maximum feature size.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:00 [backends.py:644] Using cache directory: /home/mozf/.cache/vllm/torch_compile_cache/31e1d32d9a/rank_0_0/backbone for vLLM's torch.compile
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:00 [backends.py:704] Dynamo bytecode transform time: 5.31 s
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) WARNING 01-06 00:32:03 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) WARNING 01-06 00:32:03 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:03 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) INFO 01-06 00:32:03 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) WARNING 01-06 00:32:04 [fused_moe.py:1054] Using default MoE config. Performance might be sub-optimal! Config file not found at /home/mozf/develop-projects/vllm/vllm/model_executor/layers/fused_moe/configs/E=64,N=704,device_name=NVIDIA_GeForce_RTX_3090.json
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:04 [backends.py:226] Directly load the compiled graph(s) for compile range (1, 8192) from the cache, took 1.484 s
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:04 [monitor.py:34] torch.compile takes 6.80 s in total
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) INFO 01-06 00:32:04 [backends.py:226] Directly load the compiled graph(s) for compile range (1, 8192) from the cache, took 1.493 s
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:06 [gpu_worker.py:361] Available KV cache memory: 3.21 GiB
(EngineCore_DP0 pid=1366030) INFO 01-06 00:32:06 [kv_cache_utils.py:1305] GPU KV cache size: 110,880 tokens
(EngineCore_DP0 pid=1366030) INFO 01-06 00:32:06 [kv_cache_utils.py:1310] Maximum concurrency for 4,096 tokens per request: 27.07x
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) WARNING 01-06 00:32:06 [gpu_model_runner.py:5075] CUDAGraphMode.FULL_AND_PIECEWISE is not supported with TritonMLABackend backend (support: AttentionCGSupport.NEVER); setting cudagraph_mode=PIECEWISE because attention is compiled piecewise
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) WARNING 01-06 00:32:06 [gpu_model_runner.py:5075] CUDAGraphMode.FULL_AND_PIECEWISE is not supported with TritonMLABackend backend (support: AttentionCGSupport.NEVER); setting cudagraph_mode=PIECEWISE because attention is compiled piecewise
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) WARNING 01-06 00:32:06 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1366030) (Worker_TP1 pid=1366039) INFO 01-06 00:32:06 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE):   0%|                                                                                                                 | 0/51 [00:00<?, ?it/s](EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) WARNING 01-06 00:32:06 [vllm.py:1447] Current vLLM config is not set.
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:06 [scheduler.py:231] Chunked prefill is enabled with max_num_batched_tokens=2048.
Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 51/51 [00:09<00:00,  5.58it/s]
(EngineCore_DP0 pid=1366030) (Worker_TP0 pid=1366036) INFO 01-06 00:32:16 [gpu_model_runner.py:4806] Graph capturing finished in 10 secs, took 1.11 GiB
(EngineCore_DP0 pid=1366030) INFO 01-06 00:32:16 [core.py:273] init engine (profile, create kv cache, warmup model) took 22.16 seconds
...
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.16it/s]
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.18it/s, est. speed input: 3273.86 toks/s, output: 75.61 toks/s]
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall, modern tower with a spherical observation deck and a lattice structure below. The sky is clear and blue. So the content includes cherry
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. Behind them, the Tokyo Skytree is visible against a clear blue sky. So the content includes cherry blossoms, the Tokyo Skytree, and the blue sky
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall tower, which is the Tokyo Skytree, a prominent landmark in Tokyo, Japan. The sky is clear and blue, creating
--------------------------------------------------
◁think▷So, let's analyze the image. The picture shows cherry blossoms in full bloom, with pink flowers and dark branches in the foreground. In the background, there's a tall tower with a spherical observation deck and a lattice structure below. The sky is clear and blue. So the content includes cherry blossoms,
--------------------------------------------------
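As a sanity check on the logs above, the reported "Maximum concurrency" figures follow directly from dividing the GPU KV cache size by the 4,096-token per-request budget. A quick illustrative calculation (not vLLM code):

```python
# Sanity check of the "Maximum concurrency" numbers reported in the logs:
# concurrency = kv_cache_tokens / tokens_per_request.

def max_concurrency(kv_cache_tokens: int, tokens_per_request: int) -> float:
    """Upper bound on requests that fit in the KV cache simultaneously."""
    return kv_cache_tokens / tokens_per_request

# tp=2, mm_encoder_tp_mode="weight": 124,096-token KV cache
print(f"{max_concurrency(124_096, 4_096):.2f}x")  # 30.30x

# tp=2, mm_encoder_tp_mode="data": 110,880-token KV cache
print(f"{max_concurrency(110_880, 4_096):.2f}x")  # 27.07x
```

Both values match the `kv_cache_utils.py` log lines, confirming the two runs differ only in how much memory remains for the KV cache.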


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@Isotr0py Isotr0py changed the title from "[Models]: Update ViT attention interface for MoonViT" to "[Models]: Use MMEncoderAttention for MoonViT" on Jan 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MoonViT model to use the standardized MMEncoderAttention interface, which is a good step towards improving code maintainability. It also adds support for data parallelism in the ViT encoder. However, I've found a critical issue where a signature change in MoonVitPretrainedModel breaks its usage in the KimiVL model, which will cause a runtime error. Additionally, there's another issue where the multimodal_config is not passed to MMEncoderAttention, preventing users from configuring the attention backend. Both issues need to be addressed.

I am having trouble creating individual review comments, so my feedback is included below.

vllm/model_executor/models/moonvit.py (518-525)

critical

The signature of MoonVitPretrainedModel.__init__ has been changed to accept a multimodal_config object instead of a use_data_parallel boolean. However, the call site in vllm/model_executor/models/kimi_vl.py has not been updated accordingly and still passes a boolean. This will lead to an AttributeError at runtime when multimodal_config.mm_encoder_tp_mode is accessed within MoonVitEncoderLayer. To fix this, kimi_vl.py must be updated to pass the MultiModalConfig object.
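A minimal sketch of the failure mode described above, using hypothetical stand-ins for the real classes (not vLLM code): attribute access on a bare boolean raises `AttributeError`, which is why the call site in kimi_vl.py must pass the config object rather than `use_data_parallel`.

```python
# Hypothetical stand-ins illustrating the reported bug (not vLLM code).
from dataclasses import dataclass

@dataclass
class MultiModalConfig:
    mm_encoder_tp_mode: str = "weight"

class EncoderLayer:
    def __init__(self, multimodal_config):
        # Works when given the config object, breaks on a bare bool.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

EncoderLayer(MultiModalConfig(mm_encoder_tp_mode="data"))  # fine

try:
    EncoderLayer(True)  # old-style call site passing use_data_parallel=True
except AttributeError as e:
    print(f"AttributeError: {e}")
```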

vllm/model_executor/models/moonvit.py (362-366)

high

The multimodal_config parameter is available in MoonVitEncoderLayer.__init__ but is not passed to the MMEncoderAttention constructor. This prevents users from overriding the vision encoder's attention backend via the mm_encoder_attn_backend setting in MultiModalConfig. The multimodal_config should be forwarded to MMEncoderAttention to ensure correct backend selection and configurability.

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads,
            head_size=self.hidden_size_per_attention_head,
            prefix=f"{prefix}.attn",
            multimodal_config=multimodal_config,
        )

Member Author

Isotr0py commented Jan 5, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the MoonViT model to use the centralized MMEncoderAttention layer and tensor-parallel linear layers (ColumnParallelLinear, RowParallelLinear, QKVParallelLinear). This is a great improvement for code consistency and maintainability within the vLLM project. The changes also introduce data parallelism support for the Vision Transformer (ViT) part of the model. The implementation looks solid, but I've found one minor issue regarding explicitness in the code that should be addressed.
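The fused-QKV pattern the review refers to can be illustrated outside vLLM: a `QKVParallelLinear`-style layer performs one matmul whose output is split into query/key/value chunks, rather than three separate projections. A toy sketch with hypothetical shapes (plain Python, not vLLM code):

```python
# Illustrative sketch of a fused QKV projection (toy shapes, not vLLM code).

hidden_size = 4  # model hidden dim (hypothetical)
head_size = 4    # per-head dim (hypothetical)

def linear(x, w):
    """Plain matrix-vector product; w has shape out_dim x in_dim."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

# Separate Q/K/V weights (scaled identities so results are easy to read).
w_q = [[1.0 if i == j else 0.0 for j in range(hidden_size)] for i in range(head_size)]
w_k = [[2.0 if i == j else 0.0 for j in range(hidden_size)] for i in range(head_size)]
w_v = [[3.0 if i == j else 0.0 for j in range(hidden_size)] for i in range(head_size)]

# Fused weight: Wq, Wk, Wv stacked row-wise into one 3*head_size x hidden_size matrix.
w_qkv = w_q + w_k + w_v

x = [1.0, 2.0, 3.0, 4.0]
qkv = linear(x, w_qkv)  # single matmul instead of three
q = qkv[:head_size]
k = qkv[head_size:2 * head_size]
v = qkv[2 * head_size:]

# Equivalent to the three separate projections:
assert q == linear(x, w_q)
assert k == linear(x, w_k)
assert v == linear(x, w_v)
```

The fused form does one larger GEMM (better hardware utilization) and gives tensor parallelism a single weight to shard, which is the consistency win the review highlights.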

ywang96 and others added 3 commits January 5, 2026 10:42
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: h100 <h100@inferact.ai>
@Isotr0py Isotr0py enabled auto-merge (squash) January 6, 2026 01:05
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 6, 2026
@Isotr0py Isotr0py merged commit 7101e08 into vllm-project:main Jan 6, 2026
52 checks passed
@Isotr0py Isotr0py deleted the moonvit-mha branch January 6, 2026 13:04
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: h100 <h100@inferact.ai>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: h100 <h100@inferact.ai>
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
njhill pushed a commit that referenced this pull request Jan 27, 2026
- Replace custom multihead_attention/eager_attention with MMEncoderAttention
- Add tensor parallel support using QKVParallelLinear and RowParallelLinear
- MLP2 now uses ColumnParallelLinear/RowParallelLinear with TP/DP support
- Pass multimodal_config through MoonViT3dPretrainedModel to encoder layers
- Rename vision_tower_forward_auto to vision_tower_forward

This aligns with the pattern introduced in PR #31738 for unified multi-platform attention backends.

Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026