[Llama.py -> mistral.py] Extract mistral-only relevant code into separate file#32780

Merged
DarkLight1337 merged 11 commits intovllm-project:mainfrom
patrickvonplaten:move_mistral_into_its_own_file
Jan 22, 2026

Conversation

Collaborator

@patrickvonplaten patrickvonplaten commented Jan 21, 2026

We're adding more and more Mistral-only code to llama.py, which makes it harder to read and creates potential unwanted dependencies in the future. For example, if other models depend on llama.py, one might assume that the Mistral-only code is also relevant for those classes, which would make vLLM too rigid.
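A hypothetical sketch of the layout this PR moves toward (these are not vLLM's actual class bodies): llama.py stays generic, while Mistral-only behavior lives in mistral.py as a thin subclass, so other Llama-derived models never inherit Mistral-specific code paths. The `load_weights` logic and the `wq` -> `q_proj` rename below are illustrative assumptions.

```python
# llama.py -- generic base, no Mistral-specific branches
class LlamaForCausalLM:
    def load_weights(self, weights):
        # weights: iterable of (name, tensor) pairs
        return {name: w for name, w in weights}


# mistral.py -- Mistral-only weight-name remapping lives here,
# out of the base class other models subclass
class MistralForCausalLM(LlamaForCausalLM):
    def load_weights(self, weights):
        remapped = ((name.replace("wq", "q_proj"), w) for name, w in weights)
        return super().load_weights(remapped)
```

With this split, a new Llama-derived model subclasses `LlamaForCausalLM` directly and never sees the Mistral remapping.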

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
@patrickvonplaten patrickvonplaten force-pushed the move_mistral_into_its_own_file branch from 1141f58 to 9913155 Compare January 21, 2026 13:44
@mergify mergify bot added llama Related to Llama models new-model Requests to new models labels Jan 21, 2026

# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
Collaborator Author


This code comes from a very old PR (#8168), and I'm quite convinced that only Mistral checkpoints actually make use of this function, so I'm moving it out.
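For illustration, a checkpoint-key remapper of this kind can look like the sketch below. The mapping table and function body are assumptions for demonstration, not vLLM's actual `maybe_remap_mistral` implementation.

```python
# Illustrative Mistral-format -> HF-style key mapping (assumed names,
# not vLLM's real remap table)
MISTRAL_TO_HF = {
    "attention.wq": "self_attn.q_proj",
    "attention.wk": "self_attn.k_proj",
    "attention.wv": "self_attn.v_proj",
    "attention.wo": "self_attn.o_proj",
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w2": "mlp.down_proj",
    "feed_forward.w3": "mlp.up_proj",
}


def maybe_remap_mistral(name: str) -> str:
    """Map a Mistral-format checkpoint key to an HF-style name, if it matches."""
    for src, dst in MISTRAL_TO_HF.items():
        if src in name:
            return name.replace(src, dst)
    return name  # already HF-style: pass through unchanged
```

Keeping this table in mistral.py means llama.py never needs to know the Mistral naming scheme exists.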

prefix=f"{prefix}.attn",
)

def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
Collaborator Author


Only Mistral makes use of the Llama-4 attention scaling.

assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
head_dim = getattr(config, "head_dim", None)
Collaborator Author


AFAIK only Mistral-Nemo ever used this.
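A minimal sketch of the optional `head_dim` handling in the snippet above: Mistral-Nemo sets `head_dim` explicitly (128, whereas `hidden_size // num_attention_heads` would give 160), so the attribute must take precedence over the derived value. `FakeConfig` is a stand-in, not a transformers class.

```python
class FakeConfig:
    """Stand-in for a Mistral-Nemo-style config object."""
    hidden_size = 5120
    num_attention_heads = 32
    head_dim = 128  # explicit, as introduced by Mistral-Nemo (128 != 5120 // 32)


def resolve_head_dim(config) -> int:
    # Prefer an explicit head_dim; fall back to the usual derivation.
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = config.hidden_size // config.num_attention_heads
    return head_dim
```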

Collaborator Author


Doesn't seem to be the case 😅

mergify bot commented Jan 21, 2026

Hi @patrickvonplaten, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request successfully extracts Mistral-specific model adaptations into a new file, mistral.py, and refactors llama.py to be more generic. This improves modularity and maintainability of the codebase. The changes in llama.py and registry.py are appropriate for this refactoring.

I am having trouble creating individual review comments; my feedback is below.

vllm/model_executor/mistral.py (214-233)

high

The logic within maybe_remap_mistral for handling wq and wk weights, especially with the conditional checks for qscale_weight and loaded_weight.numel() > 1, is quite complex and repetitive. This intricate logic increases the potential for errors and makes future modifications or debugging challenging. Consider refactoring this section to improve clarity and reduce duplication, perhaps by abstracting the common permutation and conditional checks into smaller, more focused helper functions.
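A hedged sketch of the refactor the review suggests: pull the shared rotary interleave of the `wq`/`wk` rows into one helper instead of repeating it inline per branch. Pure-Python row lists stand in for tensors here, and the interleave mirrors the usual Llama HF-conversion permute, not vLLM's exact code.

```python
def permute_rotary(rows: list, n_heads: int) -> list:
    """Interleave the two rotary halves of each attention head's rows."""
    per_head = len(rows) // n_heads
    half = per_head // 2
    out = []
    for h in range(n_heads):
        head = rows[h * per_head:(h + 1) * per_head]
        # [f0, f1, ..., s0, s1, ...] -> [f0, s0, f1, s1, ...]
        for a, b in zip(head[:half], head[half:]):
            out.extend([a, b])
    return out


def remap_qk(name: str, rows: list, n_heads: int, n_kv_heads: int) -> list:
    """Apply the shared permutation to wq/wk; pass other weights through."""
    if "wq" in name:
        return permute_rotary(rows, n_heads)
    if "wk" in name:
        return permute_rotary(rows, n_kv_heads)
    return rows
```

Factoring the permutation out this way makes the `qscale_weight` / `numel() > 1` conditionals the only branch-specific code left in each arm.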

Member

@DarkLight1337 DarkLight1337 left a comment


Thanks for the cleanup!

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 21, 2026
Collaborator Author

Hmm, the Docker build is probably flaky, no @DarkLight1337?

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) January 21, 2026 15:31
Collaborator Author

I think the final failing tests are unrelated:

[2026-01-21T19:36:47Z] ERROR entrypoints/openai/test_serving_chat.py::TestGPTOSSChat::test_gpt_oss_chat_tool_call_streaming[with_tool_parser-exclude_tools_when_tool_choice_none] - RuntimeError: Server failed to start in time.
[2026-01-21T18:50:16Z] ERROR entrypoints/openai/responses/test_parsable_context.py::test_basic[Qwen/Qwen3-8B] - RuntimeError: Server failed to start in time.

Good to merge, you think, @DarkLight1337?

@DarkLight1337 DarkLight1337 merged commit 1579c9b into vllm-project:main Jan 22, 2026
53 checks passed
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026
…rate file (vllm-project#32780)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
ms1design added a commit to ms1design/vllm that referenced this pull request Jan 24, 2026
…llm-project#32780

The refactor in PR vllm-project#32780 moved Mistral-specific code to mistral.py, but pixtral.py had unsafe dictionary accesses that caused KeyError when loading checkpoints with multi_modal_projector.patch_merger weights.

Changes:

- Use .get() instead of direct access for patch_merger_dict
- Use .get() instead of direct access for pre_mm_projector_norm_dict
- Improved is_patch_merger() to recognize the multi_modal_projector.patch_merger prefix
- Added null checks to gracefully handle missing weights

Fixes vllm-project#32959

Signed-off-by: Mieszko Syty <mieszko@ms1design.pl>
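A sketch of the defensive-access pattern that follow-up commit describes: optional multimodal weights are looked up with `dict.get()` so checkpoints missing them load instead of raising `KeyError`. The key names are taken from the commit message; the surrounding code is illustrative, not pixtral.py's actual variables.

```python
def is_patch_merger(name: str) -> bool:
    # Recognizes both the bare and the multi_modal_projector-prefixed form.
    return "patch_merger" in name


# A checkpoint that has the patch merger but lacks the optional norm weight:
state = {"multi_modal_projector.patch_merger.weight": [1.0]}

# .get() returns None instead of raising KeyError for absent keys
patch_merger = state.get("multi_modal_projector.patch_merger.weight")
norm = state.get("pre_mm_projector_norm.weight")

if norm is None:
    pass  # gracefully skip the missing optional weight
```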
cwazai pushed a commit to cwazai/vllm that referenced this pull request Jan 25, 2026
…rate file (vllm-project#32780)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
…rate file (vllm-project#32780)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…rate file (vllm-project#32780)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>