
[Core] Initialize LoRA support for tower and connector in multi-modal models#26674

Merged
vllm-bot merged 91 commits into vllm-project:main from jeejeelee:mlm-full-lora-support
Dec 26, 2025

Conversation

@jeejeelee
Collaborator

@jeejeelee jeejeelee commented Oct 13, 2025

Purpose

FIX #26422
FIX #27916
FIX #29635

Summary

This PR enables LoRA support for the tower modules and connector in multi-modal models. Previously, a single global Punica wrapper was used, which was insufficient because the input structures for the tower/connector differ from those of the language model.

Key Changes

To address this, the following changes were implemented:

  • Decoupled Punica Wrappers: Instead of sharing a single global wrapper, separate Punica wrappers are now instantiated for the language model, tower model, and connector.
  • New ProcessingInfo Methods: Added get_num_mm_encoder_tokens and get_num_mm_connector_tokens to the ProcessingInfo interface. These are used to calculate the actual input token counts for the vision tower and connector.
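To illustrate the decoupling, here is a minimal, hypothetical sketch (the DummyWrapper class, component names, and token budgets below are invented for illustration and are not vLLM's actual Punica API):

```python
from dataclasses import dataclass, field

@dataclass
class DummyWrapper:
    """Stand-in for a per-component Punica wrapper."""
    max_tokens: int
    metadata: dict = field(default_factory=dict)

    def update_metadata(self, token_lora_mapping: list) -> None:
        # Each component records its own LoRA token mapping,
        # sized to its own input length.
        self.metadata["mapping"] = token_lora_mapping

# One wrapper per component instead of a single global wrapper.
# The tower typically sees a different number of tokens than the
# language model, so each wrapper gets its own budget.
wrappers = {
    "language_model": DummyWrapper(max_tokens=8192),
    "visual": DummyWrapper(max_tokens=8192 * 4),
    "visual.merger": DummyWrapper(max_tokens=8192),
}

# Updating one component's metadata leaves the others untouched.
wrappers["visual"].update_metadata([0, 0, 1, 1])
```

The point of the sketch is only the shape of the mapping: one wrapper per module prefix, each updated independently.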

Why add get_num_mm_*_tokens?

The number of multi-modal tokens represented in the language model does not necessarily match the input length required by the linear layers in the vision tower or connector. Since the lora_mapping requires the precise input token length prior to activation, these helper functions are necessary to bridge the discrepancy and calculate the correct lengths.

Implementation Details

Currently, this PR implements the logic for Qwen2/2.5/3 VL and Idefics3. Below is an example of the implementation:

def get_num_mm_encoder_tokens(
    self,
    num_image_tokens: int,
) -> int:
    hf_config = self.get_hf_config()
    vision_config = hf_config.vision_config
    merge_size = vision_config.spatial_merge_size

    return num_image_tokens * merge_size**2
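As a quick sanity check of the formula above (assuming Qwen2-VL's default spatial_merge_size of 2): each image placeholder token on the language-model side corresponds to merge_size**2 patch tokens entering the tower's linear layers. A standalone version of the same arithmetic:

```python
def num_mm_encoder_tokens(num_image_tokens: int, merge_size: int = 2) -> int:
    # Same arithmetic as get_num_mm_encoder_tokens above:
    # each LM-side image token expands to merge_size**2 tower tokens.
    return num_image_tokens * merge_size**2

print(num_mm_encoder_tokens(100))  # 100 LM image tokens -> 400 tower tokens
```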

Note: To support LoRA for the tower and connector of other multi-modal models, these two functions must be implemented in their respective classes.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

linitra24 and others added 11 commits May 22, 2025 00:31
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee marked this pull request as draft October 13, 2025 03:32
@mergify mergify bot added the multi-modality, qwen, and v1 labels Oct 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces LoRA support for tower modules in multi-modal models by creating distinct punica_wrapper instances for different model components, such as the vision tower and the language model, and dispatching LoRA operations accordingly. The changes are well-structured, primarily modifying vllm/lora/models.py for the core logic and updating other files for necessary plumbing. My review identifies a few critical and high-severity issues in the core logic, including hardcoded values that should be dynamic, potentially incorrect use of configuration parameters between encoder and decoder, and unsafe list access that could lead to runtime errors. Addressing these points will improve the robustness and correctness of the implementation.

Comment on lines +512 to +530
    self.mm_punica_wrapper_mapping[
        self.mm_mapping.tower_model[0]
    ].update_metadata(
        mapping,
        self.lora_index_to_id,
        self.lora_slots + 1,
        self.vocab_size,
        self.lora_config.lora_extra_vocab_size,
    )
else:
    self.mm_punica_wrapper_mapping[
        self.mm_mapping.language_model[0]
    ].update_metadata(
        mapping,
        self.lora_index_to_id,
        self.lora_slots + 1,
        self.vocab_size,
        self.lora_config.lora_extra_vocab_size,
    )
Contributor


critical

The code accesses self.mm_mapping.tower_model[0] (line 513) and self.mm_mapping.language_model[0] (line 523) without checking if the lists are empty. If get_mm_mapping() returns a MultiModelKeys object with an empty tower_model or language_model list, this will raise an IndexError, causing a crash. A similar issue exists in __init__ on line 414. It's safer to add checks to ensure these lists are not empty before accessing their elements.

For example:

if not self.mm_mapping.tower_model:
    raise ValueError("No tower model found in mm_mapping for LoRA.")
tower_model_name = self.mm_mapping.tower_model[0]
self.mm_punica_wrapper_mapping[tower_model_name].update_metadata(...)

self.mm_config = model_config.multimodal_config
# limit_per_prompt: int = max(
# self.info.get_allowed_mm_limits().values())
limit_per_prompt = 5 # TODO
Contributor


high

The limit_per_prompt is hardcoded to 5. This could lead to unexpected behavior or limitations if not aligned with the actual model's capabilities or user's expectations. It would be better to derive this value from the model configuration, as suggested by the commented-out code, to ensure it's dynamically and correctly set. Using max(values or [0]) would also prevent a ValueError if get_allowed_mm_limits() returns an empty mapping.

Suggested change
limit_per_prompt = 5 # TODO
limit_per_prompt: int = max(self.info.get_allowed_mm_limits().values() or [0])
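For context on the `or [0]` fallback in the suggestion: an empty dict's `.values()` view is falsy in Python, so the default list kicks in instead of `max()` raising `ValueError`. A standalone illustration (not the vLLM code):

```python
def safe_limit(limits: dict) -> int:
    # dict.values() is falsy for an empty dict, so `or [0]`
    # supplies a default and max() never sees an empty sequence.
    return max(limits.values() or [0])

print(safe_limit({"image": 5, "video": 2}))  # -> 5
print(safe_limit({}))                        # -> 0
```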

Comment on lines +403 to +411
self.mm_punica_wrapper_mapping = {
    name: get_punica_wrapper(
        self.info.get_num_mm_encoder_tokens(max_num_batched_tokens),
        max_batches=self.max_num_seqs * limit_per_prompt,
        device=self.device,
        max_loras=self.lora_config.max_loras,
    )
    for name in self.mm_mapping.tower_model
}
Contributor


high

The max_num_batched_tokens parameter, which is typically configured for the decoder, is being used to determine the number of tokens for the multi-modal encoder via self.info.get_num_mm_encoder_tokens(max_num_batched_tokens). This could be incorrect as the encoder might have a different token budget. The commented-out line # max_num_batched_tokens = encoder_budget suggests that a separate encoder_budget should be used. Using the decoder's token budget for the encoder could lead to misconfiguration and potential runtime errors.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +639 to +709
    be filtered out.
    """
    if self.supports_mm:
        module_mapping: MultiModelKeys = self.model.get_mm_mapping()
        prefix_lst = module_mapping.connector + module_mapping.tower_model
        return any([module_name.startswith(prefix) for prefix in prefix_lst])
    prefix_lst = self.mm_mapping.connector + self.mm_mapping.tower_model
    if self.supports_mm_lora:
        return self._get_mm_punica_wrapper(module_name) is None
    else:
        return any([module_name.startswith(prefix) for prefix in prefix_lst])


P1: Access undefined multimodal mapping when LoRA tower unsupported

When LoRA is enabled for a multimodal model that does not yet implement tower LoRA (i.e. supports_mm is true but supports_mm_lora is false because the processing info lacks get_num_mm_encoder_tokens), _filter_unsupported_mm_module still dereferences self.mm_mapping to build the prefix list. However self.mm_mapping is only created inside __init__ when supports_mm_lora is true, so any such model will raise an AttributeError during adapter registration, breaking existing multimodal models that only supported language-model LoRA before this change. The mapping should be fetched regardless of supports_mm_lora or lazily inside the filter.
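One way to address this (a sketch with hypothetical class and method names, not the actual vLLM fix) is to fetch the mapping lazily via a property, so it is valid whenever the model is multi-modal, regardless of supports_mm_lora:

```python
from types import SimpleNamespace

class FakeModel:
    """Minimal stand-in exposing get_mm_mapping()."""
    def get_mm_mapping(self):
        return SimpleNamespace(connector=["visual.merger"],
                               tower_model=["visual"])

class ManagerSketch:
    def __init__(self, model, supports_mm, supports_mm_lora):
        self.model = model
        self.supports_mm = supports_mm
        self.supports_mm_lora = supports_mm_lora
        self._mm_mapping = None  # not eagerly built in __init__

    @property
    def mm_mapping(self):
        # Lazy fetch: safe even when tower/connector LoRA is unsupported,
        # because the mapping is built on first access, not in __init__.
        if self._mm_mapping is None:
            self._mm_mapping = self.model.get_mm_mapping()
        return self._mm_mapping

    def filter_unsupported_mm_module(self, module_name):
        if not self.supports_mm:
            return False
        prefixes = self.mm_mapping.connector + self.mm_mapping.tower_model
        return any(module_name.startswith(p) for p in prefixes)

# supports_mm=True but supports_mm_lora=False no longer raises.
mgr = ManagerSketch(FakeModel(), supports_mm=True, supports_mm_lora=False)
print(mgr.filter_unsupported_mm_module("visual.blocks.0"))  # True
print(mgr.filter_unsupported_mm_module("lm_head"))          # False
```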


prashanth058 and others added 2 commits November 20, 2025 15:04
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
@mergify mergify bot added the tpu label Nov 20, 2025
@mergify mergify bot added the needs-rebase label Dec 24, 2025
@mergify mergify bot removed the needs-rebase label Dec 24, 2025
@mergify

mergify bot commented Dec 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jeejeelee.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 24, 2025
Signed-off-by: bk-201 <joy25810@foxmail.com>
@mergify mergify bot removed the needs-rebase label Dec 24, 2025
@Isotr0py
Copy link
Copy Markdown
Member

For now, perhaps we can disable the multi-modal processor cache when enable_tower_connector_lora is active? Could we address this more thoroughly in a follow-up PR?

Sure, feel free to disable the cache when enabling mm LoRA! We can address it in follow-up PRs.

Member

@Isotr0py Isotr0py left a comment


LGTM now!

@vllm-bot vllm-bot merged commit ce1eafd into vllm-project:main Dec 26, 2025
60 of 62 checks passed
patrickvonplaten pushed a commit to patrickvonplaten/vllm that referenced this pull request Dec 27, 2025
… models (vllm-project#26674)

Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
@jeejeelee jeejeelee deleted the mlm-full-lora-support branch December 29, 2025 01:28
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Dec 30, 2025
… models (vllm-project#26674)

dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
… models (vllm-project#26674)

Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
… models (vllm-project#26674)


Labels

  • documentation: Improvements or additions to documentation
  • multi-modality: Related to multi-modality (#4194)
  • qwen: Related to Qwen models
  • ready: ONLY add when PR is ready to merge/full CI is needed
  • rocm: Related to AMD ROCm
  • tpu: Related to Google TPUs
  • v1

Projects

None yet

9 participants