Skip to content

[Models] Lfm2-VL Architecture#29191

Closed
paulpak58 wants to merge 3 commits intovllm-project:mainfrom
paulpak58:lfm2_vl
Closed

[Models] Lfm2-VL Architecture#29191
paulpak58 wants to merge 3 commits intovllm-project:mainfrom
paulpak58:lfm2_vl

Conversation

@paulpak58
Copy link
Copy Markdown
Contributor

@paulpak58 paulpak58 commented Nov 21, 2025

Purpose

LFM2-VL Implementation

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Paul Pak <paulpak58@gmail.com>
@mergify
Copy link
Copy Markdown

mergify bot commented Nov 21, 2025

Documentation preview: https://vllm--29191.org.readthedocs.build/en/29191/

@mergify mergify bot added documentation Improvements or additions to documentation new-model Requests to new models labels Nov 21, 2025
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the Lfm2-VL model, including its architecture implementation, an example script for offline inference, and necessary registrations. A new Siglip2Model implementation is also added as it's a dependency for Lfm2-VL. The overall implementation is solid, but I've found a type hint mismatch in the new lfm2_vl.py file that should be corrected for code correctness and clarity.

spatial_shapes: torch.Tensor,
pixel_attention_mask: torch.Tensor,
num_patches: torch.Tensor,
) -> torch.Tensor:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return type hint for image_pixels_to_features is torch.Tensor, but the function actually returns a list[torch.Tensor]. The caller of this function, _process_image_input, expects a list of tensors. Please update the type hint to list[torch.Tensor] to match the implementation and avoid potential type errors.

Suggested change
) -> torch.Tensor:
) -> list[torch.Tensor]:

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +390 to +394
return self.vision_model(
pixel_values=pixel_values,
# attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Mask padded vision tokens in Siglip2 forward

In Siglip2Model.forward the processor-supplied pixel_attention_mask is accepted but never forwarded to the vision transformer (attention_mask is commented out). This means padded patch tokens introduced to equalize sequence length are never masked, yet they still carry bias and positional embeddings, so the encoder will mix these fake tokens into attention for any image smaller than the maximum patch budget—a common case—yielding incorrect visual embeddings. The mask needs to be propagated to the encoder or applied inside the attention layers to exclude padding.

Useful? React with 👍 / 👎.

self.vision_tower = Siglip2Model(
config=vision_config,
quant_config=quant_config,
prefix=f"{prefix}.vit",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a typo in the prefix

quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = MultiHeadAttention(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Isotr0py @tjtanaa why does the existing SigLIP implementation not use MultiHeadAttention yet?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean siglip2navit? Because it needs cu_seq_lens to build attention mask, which hasn't been integrated into MultiHeadAttention.

We will consolidate it with MultiHeadAttention after the vision attention refactoring (waiting #27919)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

return attn_output


class Siglip2MLP(nn.Module):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the other layers, you can import directly from siglip2 file if they are the same

paulpak58 and others added 2 commits November 25, 2025 15:29
@github-actions
Copy link
Copy Markdown

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Feb 27, 2026
@DarkLight1337
Copy link
Copy Markdown
Member

Closing as superseded by #31758

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation new-model Requests to new models stale Over 90 days of inactivity

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants