
[Model] Use mm_position to compute mrope positions for Qwen2-VL/2.5-VL#32126

Merged
Isotr0py merged 4 commits into vllm-project:main from
YunzhuLu:refactor-mrope-position-computation
Jan 13, 2026

Conversation

@YunzhuLu
Contributor

@YunzhuLu YunzhuLu commented Jan 11, 2026

Purpose

Followup to #28399 and #28730. This PR optimizes the M-RoPE position computation for both Qwen2-VL and Qwen2.5-VL.

  • Leverages the pre-parsed mm_features (via mm_position.offset) to compute M-RoPE positions, instead of scanning input tokens.
  • Simplifies index generation with NumPy-based vectorization.
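As a rough illustration of the vectorized approach (a minimal sketch with a hypothetical helper name, not the actual vLLM code), the 3D position IDs for one image feature can be produced by a single `np.indices` call instead of per-token loops:

```python
import numpy as np

def image_mrope_positions(t, h, w, start, spatial_merge_size=2):
    """Build 3D (temporal, height, width) M-RoPE position IDs for one
    image feature. Grid dims t, h, w are in vision-patch units;
    spatial_merge_size merges patches into LLM tokens."""
    llm_h = h // spatial_merge_size
    llm_w = w // spatial_merge_size
    # np.indices yields a (3, t, llm_h, llm_w) stack of t/h/w coordinates;
    # flattening gives one column of (t, h, w) IDs per LLM token.
    pos = np.indices((t, llm_h, llm_w)).reshape(3, -1)
    # Offset by the text position immediately preceding this feature.
    return pos + start

# A 1-frame 4x4-patch image with merge size 2 -> 2x2 = 4 LLM tokens.
positions = image_mrope_positions(1, 4, 4, start=10)
```

For video, the temporal row would additionally be scaled by a `t_factor` (per the Qwen2.5-VL `second_per_grid_ts * tokens_per_second` scheme described below); the sketch omits that for brevity.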

Test Plan

VLLM_WORKER_MULTIPROC_METHOD=spawn lm_eval --model vllm-vlm --model_args "pretrained=Qwen/Qwen2-VL-7B-Instruct,max_model_len=8192" --tasks chartqa --batch_size 1 --apply_chat_template --seed 42

Test Result

Qwen2-VL test results:

Before:

Tasks   | Version | Filter | n-shot | Metric            |  Value |  Stderr
chartqa |       0 | none   |      0 | anywhere_accuracy | 0.6860 | ±0.0093
        |         | none   |      0 | exact_match       | 0.3476 | ±0.0095
        |         | none   |      0 | relaxed_accuracy  | 0.4112 | ±0.0098

After:

Tasks   | Version | Filter | n-shot | Metric            |  Value |  Stderr
chartqa |       0 | none   |      0 | anywhere_accuracy | 0.6860 | ±0.0093
        |         | none   |      0 | exact_match       | 0.3476 | ±0.0095
        |         | none   |      0 | relaxed_accuracy  | 0.4112 | ±0.0098

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Note

Optimizes M-RoPE position generation by leveraging pre-parsed multimodal features and vectorized index math.

  • Adds iter_mm_grid_thw to iterate mm_features by mm_position.offset, yielding (t,h,w) grids and temporal scale (t_factor) for images/videos
  • Replaces token-scanning logic with offset-based segmentation and NumPy (np.indices, np.broadcast_to) to build 3D position IDs, then converts back to torch
  • Applies spatial merging (spatial_merge_size) and temporal scaling using second_per_grid_ts * tokens_per_second for video
  • Minor imports: introduce Iterator, add NumPy; updates both qwen2_vl.py and qwen2_5_vl.py consistently

Written by Cursor Bugbot for commit c370ddc7807818b64d98d78873931f675ab66d60.
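For intuition, the iterator described above might look roughly like this. This is a hedged sketch, not the merged code: the attribute names (mm_position.offset, grid_thw, second_per_grid_ts, tokens_per_second) follow the summary, and the feature objects are mocked.

```python
from collections.abc import Iterator
from types import SimpleNamespace as NS

def iter_mm_grid_thw(mm_features) -> Iterator[tuple[int, tuple[int, int, int], float]]:
    # Walk multimodal features in prompt order (by mm_position.offset),
    # yielding each feature's start offset, its (t, h, w) grid, and the
    # temporal scale. Videos stretch temporal positions; images use 1.0.
    for f in sorted(mm_features, key=lambda f: f.mm_position.offset):
        t_factor = getattr(f, "second_per_grid_ts", 1.0) * getattr(
            f, "tokens_per_second", 1.0
        )
        yield f.mm_position.offset, tuple(f.grid_thw), t_factor

# Mocked features: a video at offset 7 and an image at offset 1.
feats = [
    NS(mm_position=NS(offset=7), grid_thw=(4, 8, 8),
       second_per_grid_ts=0.5, tokens_per_second=4.0),
    NS(mm_position=NS(offset=1), grid_thw=(1, 8, 8)),
]
out = list(iter_mm_grid_thw(feats))
```

Sorting by offset means the caller can segment the prompt into text spans and multimodal spans in a single forward pass, which is what removes the token-by-token scan.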


@YunzhuLu YunzhuLu requested a review from sighingnow as a code owner January 11, 2026 19:09
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the qwen Related to Qwen models label Jan 11, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the M-RoPE position computation for Qwen2-VL and Qwen2.5-VL models. The changes significantly simplify the logic by introducing a helper method iter_mm_grid_thw to iterate over multimodal features and by using NumPy for vectorized position calculations. This is a great improvement in terms of both readability and performance. I've found one minor issue with a type hint that should be corrected for static analysis correctness.

@YunzhuLu YunzhuLu force-pushed the refactor-mrope-position-computation branch 2 times, most recently from 76c6550 to 23cc4d3 on January 11, 2026 at 19:25
Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
@YunzhuLu YunzhuLu force-pushed the refactor-mrope-position-computation branch from 23cc4d3 to 95645a5 on January 12, 2026 at 00:43
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant optimization for M-RoPE position computation in Qwen2-VL and Qwen2.5-VL models. By leveraging pre-calculated multimodal features and vectorized NumPy operations, the new implementation is cleaner, more readable, and more efficient than the previous token-scanning approach. The introduction of the iter_mm_grid_thw helper function is a good abstraction. I've identified a potential crash when handling empty inputs, which was also present in the previous implementation, and have provided suggestions to address this edge case.

)

- llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)

high

np.concatenate will raise a ValueError if llm_pos_ids_list is empty (e.g., for an empty prompt), which will cause a crash. This should be handled to avoid unexpected failures on valid edge cases.

Suggested change
- llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+ llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) if llm_pos_ids_list else np.empty((3, 0), dtype=np.int64)


- llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

high

If llm_positions becomes an empty array after the change in the previous line, calling .max() on it will raise a ValueError. This also needs to be handled to prevent a crash.

Suggested change
- mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+ mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() if llm_positions.size > 0 else 0
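Taken together, the two suggestions amount to a guard like the following. This is a sketch with a hypothetical helper name, not the code actually merged:

```python
import numpy as np

def concat_positions(llm_pos_ids_list, num_tokens):
    # Concatenate (3, n_i) position chunks; fall back to an empty (3, 0)
    # array and a zero delta when the prompt produced no chunks at all,
    # so neither np.concatenate nor .max() raises on an empty input.
    if llm_pos_ids_list:
        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = int(llm_positions.max()) + 1 - num_tokens
    else:
        llm_positions = np.empty((3, 0), dtype=np.int64)
        mrope_position_delta = 0
    return llm_positions, mrope_position_delta
```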

@YunzhuLu
Contributor Author

/CC @DarkLight1337, thanks for taking a look!

@Isotr0py Isotr0py enabled auto-merge (squash) January 13, 2026 01:32
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 13, 2026
@Isotr0py Isotr0py disabled auto-merge January 13, 2026 02:57
@Isotr0py Isotr0py enabled auto-merge (squash) January 13, 2026 02:57
@YunzhuLu
Contributor Author

Hi @Isotr0py,
The CI is stuck on a test_mcp_tools.py failure (due to insufficient GPU memory on the CI node), which is unrelated to M-RoPE changes.
Could you please help bypass this or manually merge? Thanks!

Member

@ywang96 ywang96 left a comment


Out of curiosity - have you done profiling to quantify the perf gain from this PR (similar to #28730 (comment))? Thanks

@Isotr0py Isotr0py merged commit 542a405 into vllm-project:main Jan 13, 2026
56 checks passed
TomerBN-Nvidia pushed a commit to TomerBN-Nvidia/vllm that referenced this pull request Jan 13, 2026
vllm-project#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
@YunzhuLu
Contributor Author

@ywang96 Here are the trace results for Qwen2-VL:

Image

Before: (trace screenshot)
After: (trace screenshot)

Video

Before: (trace screenshot)
After: (trace screenshot)

sammysun0711 pushed a commit to sammysun0711/vllm that referenced this pull request Jan 16, 2026
vllm-project#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
vllm-project#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
vllm-project#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Etelis pushed a commit to Etelis/vllm that referenced this pull request Jan 21, 2026
Refactor get_mrope_input_positions() to use mm_feature.mm_position.offset
directly instead of searching through input_tokens token-by-token.

Changes:
- Add iter_mm_features() iterator that yields (offset, modality, data)
  sorted by mm_position.offset for all 3 modalities (audio, image, video)
- Add _get_audio_for_video_mapping() for use_audio_in_video pairing
- Add _compute_audio_token_count() and _compute_interleaved_positions()
  helper methods
- Refactor get_mrope_input_positions() to iterate by offset using numpy
- Remove unused get_llm_pos_ids_for_vision import

Follows the pattern established in PR vllm-project#32126 for Qwen2-VL/2.5-VL.

Resolves: vllm-project#32656 (Qwen2.5-Omni item)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
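The iterator named in the commit message above might be sketched roughly as follows. This is a hypothetical illustration of `iter_mm_features` (the real signatures in that fork may differ), with mocked feature objects:

```python
from types import SimpleNamespace as NS

def iter_mm_features(mm_features_by_modality):
    # Flatten {"audio": [...], "image": [...], "video": [...]} into a
    # single stream of (offset, modality, feature) tuples, sorted by
    # mm_position.offset so positions can be built in prompt order.
    flat = [
        (f.mm_position.offset, modality, f)
        for modality, feats in mm_features_by_modality.items()
        for f in feats
    ]
    return sorted(flat, key=lambda item: item[0])

# Mocked features for two modalities.
mm = {
    "image": [NS(mm_position=NS(offset=12))],
    "audio": [NS(mm_position=NS(offset=3))],
}
order = [(off, mod) for off, mod, _ in iter_mm_features(mm)]
```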
@YunzhuLu YunzhuLu deleted the refactor-mrope-position-computation branch January 25, 2026 08:01
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
vllm-project#32126)

Signed-off-by: YunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Labels

qwen — Related to Qwen models; ready — ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants