[Model] Enable LoRA support for tower and connector in H2OVL#31696

Merged
jeejeelee merged 2 commits into vllm-project:main from
shwetha-s-poojary:h2ovl-lora-mm-support
Mar 18, 2026

Conversation

@shwetha-s-poojary
Contributor

@shwetha-s-poojary shwetha-s-poojary commented Jan 5, 2026

Summary

Implemented get_num_mm_encoder_tokens() and get_num_mm_connector_tokens() in H2OVLProcessingInfo to enable LoRA support for the vision tower and connector. The H2OVL connector uses a 1:1 token mapping with the encoder, so both methods return the same count as the input image tokens.
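For reference, a minimal sketch of the 1:1 mapping described above (a toy stand-in, not vLLM's actual H2OVLProcessingInfo API; method names are taken from this PR, the class here is hypothetical):

```python
class H2OVLTokenCounts:
    """Toy stand-in illustrating the token-count hooks added in this PR."""

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        # The vision encoder emits one embedding per input image token.
        return num_image_tokens

    def get_num_mm_connector_tokens(self, num_image_tokens: int) -> int:
        # The MLP connector does not merge or drop tokens (1:1 mapping),
        # so the connector count equals the encoder count.
        return self.get_num_mm_encoder_tokens(num_image_tokens)


info = H2OVLTokenCounts()
print(info.get_num_mm_encoder_tokens(256))    # 256
print(info.get_num_mm_connector_tokens(256))  # 256
```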

Related to: #31479

Technical details

H2OVL-Mississippi uses a ViT-MLP-LLM architecture with InternViT-300M as the vision encoder. Images are split into 448×448 tiles (1–6 per image), producing 256–1,590 visual tokens. Pixel shuffling reduces each tile to 256 tokens, and MSAC adjusts the number of tiles at multiple scales. The MLP connector does not reduce token count, and a resized 448×448 thumbnail is included for full-image context.
Ref: paper
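As a sanity check on the 256-tokens-per-tile figure, the arithmetic works out if one assumes InternViT's 14×14 patch size and a 0.5× pixel-shuffle scale (InternVL-family defaults; these values are assumptions here, not stated in this PR, so verify against the model config):

```python
# Rough arithmetic for tokens per 448x448 tile, assuming a 14x14 patch
# size and a 0.5 pixel-shuffle scale (assumed InternVL-family defaults).
tile_size = 448
patch_size = 14
pixel_shuffle_scale = 0.5

patches_per_side = tile_size // patch_size  # 32 patches along each side
patches_per_tile = patches_per_side ** 2    # 1024 patches per tile
# Pixel shuffle downsamples both spatial dimensions by the scale factor.
tokens_per_tile = int(patches_per_tile * pixel_shuffle_scale ** 2)
print(tokens_per_tile)  # 256
```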

Test

Added a test verifying that the encoder token count is a multiple of 256, that the connector token count matches the encoder's, and that budget clipping works.
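A hedged sketch of the checks described above (function names and the stubs here are hypothetical; the actual test in the PR goes through vLLM's processing-info APIs):

```python
TOKENS_PER_TILE = 256
MAX_TILES = 6  # dynamic tiling range per image; a thumbnail tile may be added

def num_encoder_tokens(num_tiles: int) -> int:
    # Stub: each 448x448 tile contributes 256 tokens after pixel shuffle.
    return num_tiles * TOKENS_PER_TILE

def num_connector_tokens(num_tiles: int) -> int:
    # Stub: the connector preserves the encoder token count (1:1).
    return num_encoder_tokens(num_tiles)

def test_token_counts():
    for tiles in range(1, MAX_TILES + 2):  # tiles plus thumbnail
        enc = num_encoder_tokens(tiles)
        assert enc % TOKENS_PER_TILE == 0          # multiples of 256
        assert num_connector_tokens(tiles) == enc  # connector matches encoder
    # Budget clipping: counts above a token budget should be clipped to it.
    budget = 512
    assert min(num_encoder_tokens(MAX_TILES), budget) == budget

test_token_counts()
print("ok")
```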

This is my understanding of the model; I’m open to feedback or suggestions for additional tests to further verify its correctness.

@mergify

mergify bot commented Jan 5, 2026

Documentation preview: https://vllm--31696.org.readthedocs.build/en/31696/

@mergify mergify bot added the documentation Improvements or additions to documentation label Jan 5, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables LoRA support for the vision tower and connector in H2OVL by implementing the necessary methods in H2OVLProcessingInfo. The changes look mostly correct, but there is a critical issue in get_mm_mapping where an incorrect name is used for the connector module. This will prevent LoRA from being applied to the connector. I've left a specific comment with a suggested fix.

@shwetha-s-poojary shwetha-s-poojary changed the title [Model] Enable LoRA support for tower and connector in H2OVL [WIP][Model] Enable LoRA support for tower and connector in H2OVL Jan 5, 2026
@github-actions

github-actions bot commented Jan 5, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@shwetha-s-poojary shwetha-s-poojary changed the title [WIP][Model] Enable LoRA support for tower and connector in H2OVL [Model] Enable LoRA support for tower and connector in H2OVL Jan 5, 2026
@jeejeelee jeejeelee self-assigned this Jan 5, 2026
@shwetha-s-poojary shwetha-s-poojary changed the title [Model] Enable LoRA support for tower and connector in H2OVL [WIP][Model] Enable LoRA support for tower and connector in H2OVL Jan 5, 2026
@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label Jan 6, 2026
@shwetha-s-poojary shwetha-s-poojary changed the title [WIP][Model] Enable LoRA support for tower and connector in H2OVL [Model] Enable LoRA support for tower and connector in H2OVL Jan 6, 2026
@chaunceyjiang
Collaborator

Sorry, I mentioned the wrong issue number. Please ignore it.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

@shwetha-s-poojary
Contributor Author

@jeejeelee can you take a look?

@shwetha-s-poojary
Contributor Author

@jeejeelee / @DarkLight1337 can someone review this?

@linitra24
Contributor

...
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/model_executor/models/intern_vit.py", line 228, in forward
(EngineCore_DP0 pid=2705583)     qkv, _ = self.qkv(x)
(EngineCore_DP0 pid=2705583)              ^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
(EngineCore_DP0 pid=2705583)     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=2705583)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
(EngineCore_DP0 pid=2705583)     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=2705583)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/lora/layers/column_parallel_linear.py", line 136, in forward
(EngineCore_DP0 pid=2705583)     output_parallel = self.apply(input_, bias)
(EngineCore_DP0 pid=2705583)                       ^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/lora/layers/base_linear.py", line 134, in apply
(EngineCore_DP0 pid=2705583)     lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
(EngineCore_DP0 pid=2705583)                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/lora/punica_wrapper/punica_gpu.py", line 230, in add_lora_linear
(EngineCore_DP0 pid=2705583)     self.add_shrink(
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/lora/punica_wrapper/punica_gpu.py", line 101, in add_shrink
(EngineCore_DP0 pid=2705583)     lora_shrink(
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/.venv/lib/python3.12/site-packages/torch/_ops.py", line 1255, in __call__
(EngineCore_DP0 pid=2705583)     return self._op(*args, **kwargs)
(EngineCore_DP0 pid=2705583)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(EngineCore_DP0 pid=2705583)     return func(*args, **kwargs)
(EngineCore_DP0 pid=2705583)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583)   File "/mnt/vllm_dir/my/vllm/vllm/lora/ops/triton_ops/lora_shrink_op.py", line 180, in _lora_shrink
(EngineCore_DP0 pid=2705583)     assert token_lora_mapping.size(0) == M
(EngineCore_DP0 pid=2705583)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=2705583) AssertionError

@shwetha-s-poojary With the current implementation, H2OVL fails to initialize properly when LoRA is enabled. I suspect this is because the vision encoder tokens and the visual tokens in the language model do not have a 1:1 mapping for this model.

@shwetha-s-poojary
Contributor Author

> @shwetha-s-poojary With the current implementation, H2OVL fails to initialize properly when LoRA is enabled. I suspect this is because the vision encoder tokens and the visual tokens in the language model do not have a 1:1 mapping for this model.

Thanks for catching this. I am working on fixing the token count mismatch for H2OVL.
Could you share how you’re testing this (model + LoRA setup)?

@linitra24
Contributor

> @shwetha-s-poojary With the current implementation, H2OVL fails to initialize properly when LoRA is enabled. I suspect this is because the vision encoder tokens and the visual tokens in the language model do not have a 1:1 mapping for this model.
>
> Thanks for catching this. I am working on fixing the token count mismatch for H2OVL. Could you share how you’re testing this (model + LoRA setup)?

Sure, you can use the following code to reproduce the issue:

```python
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


def main():
    # Create an LLM.
    llm = LLM(
        model="h2oai/h2ovl-mississippi-800m",
        enable_lora=True,
        enforce_eager=True,
        enable_tower_connector_lora=True,
        mm_processor_cache_gb=0,
        trust_remote_code=True,
    )

    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt:    {prompt!r}")
        print(f"Output:    {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
```

@shwetha-s-poojary
Contributor Author

@linitra24 Hi all, apologies for the slow reply; feel free to have someone else pick this up if needed, as I may not respond immediately.
IMO, the token-count mapping functions (get_num_mm_encoder_tokens and get_num_mm_connector_tokens) preserve the token count for H2OVL correctly. The LoRA failures (AssertionError in token_lora_mapping) look like they stem from embedding reshaping in the connector MLP, rather than from the token count itself.
I’m not entirely sure of the best way to handle LoRA for this model, so I would really appreciate any guidance or suggestions on how to apply embedding-level LoRA correctly here.

@shwetha-s-poojary shwetha-s-poojary force-pushed the h2ovl-lora-mm-support branch 2 times, most recently from 08bf532 to 03a9bab Compare February 17, 2026 08:05
@linitra24
Contributor

Hi @shwetha-s-poojary, apologies for the delayed response. Given that InternVL (#32397) has already added LoRA support and implemented the get_num_mm_*_tokens() methods, and since H2OVL shares the same InternVisionModel encoder, you should be able to simply follow the InternVL implementation.

@shwetha-s-poojary
Contributor Author

@linitra24 I've followed the same approach and updated the PR. PTAL.

@shwetha-s-poojary shwetha-s-poojary force-pushed the h2ovl-lora-mm-support branch 2 times, most recently from 33751f7 to 167f2ea Compare March 2, 2026 05:22
Implement get_num_mm_encoder_tokens() and get_num_mm_connector_tokens()
in H2OVLProcessingInfo

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
@jeejeelee jeejeelee enabled auto-merge (squash) March 12, 2026 14:07
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 12, 2026
@jeejeelee jeejeelee merged commit cef1f30 into vllm-project:main Mar 18, 2026
54 checks passed
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…oject#31696)

Signed-off-by: shwetha-s-poojary <shwetha.s-poojary@ibm.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

documentation: Improvements or additions to documentation
multi-modality: Related to multi-modality (#4194)
ready: ONLY add when PR is ready to merge/full CI is needed


4 participants