
[Feature] Add LoRA tower/connector support for Llama 4 Vision (mllama4)#35147

Merged
vllm-bot merged 2 commits into vllm-project:main from dorhuri123:add-llama4-vision-lora-tower-connector
Feb 24, 2026

Conversation

Contributor

@dorhuri123 commented Feb 23, 2026

Purpose

Enable LoRA adapters for the vision tower and connector of Llama 4 Vision (Llama4ForConditionalGeneration / mllama4.py), as part of #31479.

Previously, LoRA could only be applied to the language model layers. With this change, --enable-tower-connector-lora also applies LoRA to:

  • Tower — vision encoder attention layers (vision_model.model.layers.*.self_attn)
  • Connector — vision adapter MLP (vision_model.vision_adapter.mlp) and multi-modal projector (multi_modal_projector)

Changes (1 file, 23 lines)

  1. get_mm_mapping() — Updated the connector entry from a single string to a list containing both the multi_modal_projector. and vision_model.vision_adapter. prefixes. Because the LoRA manager uses longest-prefix matching, vision_model.vision_adapter.* modules correctly map to the connector wrapper rather than the tower.

  2. get_num_mm_encoder_tokens() — Converts LM-level image token count back to vision encoder token count. The encoder processes (image_size/patch_size)² + 1 tokens per chunk (raw patches + CLS token), while the LM sees patches_per_chunk tokens (post pixel-shuffle, fewer).

  3. get_num_mm_connector_tokens() — Converts encoder token count to connector token count (post pixel-shuffle). The connector (vision_adapter MLP + multi_modal_projector) processes the reduced token count.
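The longest-prefix matching behind change 1 can be sketched in a few lines. This is an illustration only: the prefix lists mirror the mapping described above, but the classify() helper is hypothetical and is not vLLM's actual API.

```python
# Illustrative sketch of longest-prefix matching between tower and
# connector prefixes. classify() is a hypothetical helper, not vLLM code.
TOWER_PREFIXES = ["vision_model."]
CONNECTOR_PREFIXES = ["multi_modal_projector.", "vision_model.vision_adapter."]

def classify(module_name: str) -> str:
    """Map a module name to 'tower' or 'connector' by longest matching prefix."""
    best_group, best_len = "", -1
    for group, prefixes in (("tower", TOWER_PREFIXES),
                            ("connector", CONNECTOR_PREFIXES)):
        for prefix in prefixes:
            if module_name.startswith(prefix) and len(prefix) > best_len:
                best_group, best_len = group, len(prefix)
    return best_group

# vision_adapter modules match both "vision_model." and the longer
# "vision_model.vision_adapter." prefix, so the connector wins:
classify("vision_model.vision_adapter.mlp.fc1")           # connector
classify("vision_model.model.layers.0.self_attn.q_proj")  # tower
classify("multi_modal_projector.linear_1")                # connector
```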

Token flow

Image → Patch Embedding → (image_size/patch_size)² = 1296 patches
  → Add CLS token → 1297 tokens
  → Vision Encoder (tower) → 1297 tokens        ← get_num_mm_encoder_tokens
  → Remove CLS → 1296 tokens
  → Pixel Shuffle (1/4 downsample) → 324 tokens
  → Vision Adapter MLP (connector) → 324 tokens  ← get_num_mm_connector_tokens
  → Multi-Modal Projector (connector) → 324 tokens
  → Language Model → 324 tokens per chunk

Values shown for Llama 4 Scout (image_size=504, patch_size=14, pixel_shuffle_ratio=0.5).
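The arithmetic in the diagram can be checked directly with the Scout constants quoted above (variable names here are illustrative, not the model's config keys):

```python
# Reproduce the token-flow arithmetic for Llama 4 Scout
# (image_size=504, patch_size=14, pixel_shuffle_ratio=0.5, per the PR).
image_size, patch_size, pixel_shuffle_ratio = 504, 14, 0.5

patches = (image_size // patch_size) ** 2              # 36^2 = 1296
encoder_tokens = patches + 1                           # +1 CLS token -> 1297
# Pixel shuffle scales each spatial dim by the ratio, so the token
# count shrinks by ratio^2 = 1/4.
connector_tokens = int(patches * pixel_shuffle_ratio ** 2)  # 324

print(patches, encoder_tokens, connector_tokens)       # 1296 1297 324
```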

Test Plan

Tested on 4x NVIDIA H100 80GB with Llama 4 Scout 17B-16E (nvidia/Llama-4-Scout-17B-16E-Instruct-FP8).

1. Create a test LoRA adapter targeting tower + connector + LM layers

from transformers import AutoConfig, Llama4ForConditionalGeneration
from accelerate import init_empty_weights
from peft import LoraConfig, get_peft_model

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
with init_empty_weights():
    model = Llama4ForConditionalGeneration(config)

lora_config = LoraConfig(
    r=8, lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2", "linear_1"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
# Save the adapter (~28 MB, safetensors) to the path used by --lora-modules below
peft_model.save_pretrained("/tmp/llama4-vision-test-lora")

Verified LoRA coverage across all 4 module groups: tower (vision encoder attention), connector (vision adapter MLP), connector (multi-modal projector), and language model.

2. Serve with tower/connector LoRA enabled

python -m vllm.entrypoints.openai.api_server \
  --model nvidia/Llama-4-Scout-17B-16E-Instruct-FP8 \
  --enable-lora \
  --lora-modules test-lora=/tmp/llama4-vision-test-lora \
  --enable-tower-connector-lora \
  --max-lora-rank 8 \
  --tensor-parallel-size 4 \
  --max-model-len 4096 \
  --enforce-eager \
  --gpu-memory-utilization 0.95 \
  --port 8000

Server started successfully with tower/connector LoRA active:

LoRA for the tower and connector of multimodal models is experimental...
Loaded new LoRA adapter: name 'test-lora', path '/tmp/llama4-vision-test-lora'
Starting vLLM API server 0 on http://0.0.0.0:8000

3. Run inference tests

Three tests via the OpenAI-compatible API:

  • Test 1: Vision inference without LoRA (baseline)
  • Test 2: Vision inference with LoRA (tower + connector + LM)
  • Test 3: Text-only inference with LoRA
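A request for Test 2 (vision with LoRA) against the server started in step 2 might be shaped as follows. This is a payload sketch for the OpenAI-compatible /v1/chat/completions endpoint; the base64 image content is elided, and the prompt text is illustrative.

```python
import json

# Sketch of the Test 2 request body. "test-lora" is the adapter name
# registered via --lora-modules; the data URI would carry the dice PNG.
payload = {
    "model": "test-lora",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url",
             "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }],
    "max_tokens": 128,
}

# To send against the running server:
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" -d "$(python -c 'print(...)')"
print(json.dumps(payload)[:40])
```

Test 1 uses the same payload with "model" set to the base model name, and Test 3 drops the image_url part.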

Test Result

All 3 tests passed. The test image is a PNG with four colored dice:

| Test | Model | Status | Output |
|---|---|---|---|
| 1 — Vision (baseline) | base model | PASS | "The image features four dice, each with a distinct color: blue, red, green, and yellow. The dice are depicted in mid-air, as if they have been tossed or thrown, showcasing their reflective surfaces and white dots that represent numbers." |
| 2 — Vision (with LoRA) | test-lora | PASS | "The image depicts four dice suspended in mid-air, each with a distinct color: red, blue, green, and yellow. The dice are positioned against a white background, which serves to accentuate their vibrant colors. Notably, the red die is prominently displayed at the center of the image, while the other three dice appear slightly blurred and are arranged around it." |
| 3 — Text (with LoRA) | test-lora | PASS | "4" (Q: What is 2+2?) |
  • All tests passed: YES
  • Outputs differ between baseline and LoRA: YES (expected with stochastic sampling; adapter is randomly initialized, not trained)
  • No "will be ignored" warnings for LoRA modules
  • No token count mismatch errors

Reference pattern

This follows the same approach as InternVL2 (#32397), which also combines pixel shuffle with CLS token handling, and Qwen2VL, where the merger prefix is likewise a sub-prefix of the vision tower prefix.

@mergify mergify bot added the llama Related to Llama models label Feb 23, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully enables LoRA support for the vision tower and connector of the Llama 4 Vision model. The implementation correctly defines the module mapping for LoRA targeting, using longest-prefix matching to distinguish between tower and connector components. Additionally, it provides the necessary methods to calculate token counts for the vision encoder and connector, accounting for the pixel shuffle and CLS token handling specific to the Llama 4 Vision architecture. The logic aligns with existing patterns in vLLM for multimodal LoRA support.

Implement get_num_mm_encoder_tokens() and get_num_mm_connector_tokens()
for Llama4ForConditionalGeneration so LoRA adapters can be applied to
the vision encoder (tower) and connector modules.

Also update get_mm_mapping() to separate vision_model.vision_adapter
into the connector prefix, since the adapter MLP processes post-pixel-shuffle
tokens (different count from the encoder layers).

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Collaborator

@jeejeelee jeejeelee left a comment


I assume you have tested this locally.

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 24, 2026
@dorhuri123
Contributor Author

Yes! Tested end-to-end on 4x H100 80GB with nvidia/Llama-4-Scout-17B-16E-Instruct-FP8 and a LoRA adapter targeting tower, connector, and LM layers. All 3 inference tests passed — baseline and LoRA both produce valid outputs. Full test results are in the PR description.

@vllm-bot vllm-bot merged commit 012dee9 into vllm-project:main Feb 24, 2026
53 of 56 checks passed
tom-zju pushed a commit to tom-zju/vllm that referenced this pull request Feb 26, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm that referenced this pull request Apr 4, 2026
…4) (vllm-project#35147)

Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>