Skip to content

Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim#30389

Merged
hmellor merged 17 commits intovllm-project:mainfrom
hmellor:simplify-get-rope
Dec 11, 2025
Merged

Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim#30389
hmellor merged 17 commits intovllm-project:mainfrom
hmellor:simplify-get-rope

Conversation

@hmellor
Copy link
Member

@hmellor hmellor commented Dec 10, 2025

Implements #30349 (comment).

Globally applies the fix from #30384.


rotary_dim as a config field only exists in GPT-J and a few custom models (Minimax for example). The Transformers library has standardised on partial_rotary_factor which lives in rope_parameters.

This PR removes the rotary_dim argument from get_rope so that there is now only one way to set the rotary_dim for RoPE:

  • Pass rope_parameters with partial_rotary_factors populated
  • get_rope does rotary_dim = head_dim * partial_rotary_factor internally

For the few edge cases that did not set config.rope_parameters["partial_rotary_factor"] it has been reverse engineered from config.rotary_dim.

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@chatgpt-codex-connector
Copy link

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

@mergify mergify bot added deepseek Related to DeepSeek models llama Related to Llama models performance Performance-related issues qwen Related to Qwen models gpt-oss Related to GPT-OSS models labels Dec 10, 2025
@hmellor hmellor changed the title Simplify-get-rope Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim Dec 10, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request simplifies the get_rope function by removing the rotary_dim argument and deriving it from head_size and partial_rotary_factor instead. The changes are applied across a large number of model files. The refactoring is well-executed, but I've found several instances where the shared config object is modified in-place. This can lead to unexpected side effects and should be avoided. I've provided suggestions to create copies of rope_parameters before modification to prevent this.

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Dec 10, 2025
@DarkLight1337
Copy link
Member

Can you resolve the merge conflict?

@mergify
Copy link

mergify bot commented Dec 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @hmellor.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 10, 2025
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor hmellor enabled auto-merge (squash) December 10, 2025 13:15
@mergify mergify bot removed the needs-rebase label Dec 10, 2025
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor hmellor linked an issue Dec 11, 2025 that may be closed by this pull request
1 task
…n-standard names

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@hmellor hmellor merged commit cf3eacf into vllm-project:main Dec 11, 2025
60 checks passed
@hmellor hmellor deleted the simplify-get-rope branch December 11, 2025 20:45
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…"]`, not `rotary_dim` (vllm-project#30389)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…"]`, not `rotary_dim` (vllm-project#30389)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek Related to DeepSeek models gpt-oss Related to GPT-OSS models llama Related to Llama models performance Performance-related issues qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: QuantTrio/MiniMax-M2-AWQ produces garbage in 12/10/2025 build

5 participants