
[Bugfix] Fix qwen-moe packed_modules_mapping#26634

Merged
Isotr0py merged 4 commits into vllm-project:main from jeejeelee:fix-qwen-moe-mapping
Oct 11, 2025
Conversation

@jeejeelee
Collaborator

@jeejeelee jeejeelee commented Oct 11, 2025

Purpose

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee jeejeelee requested a review from sighingnow as a code owner October 11, 2025 11:43
@mergify mergify bot added the qwen Related to Qwen models label Oct 11, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix an issue with packed_modules_mapping for Qwen MoE models. The change conditionally adds gate_up_proj to the mapping. However, the implementation introduces a critical bug by modifying a class attribute (packed_modules_mapping) from an instance, which can cause state to leak between different model instances. Additionally, the condition used to determine the existence of dense MLP layers is not robust and can fail for certain model configurations. I've provided comments with suggested fixes for both qwen2_moe.py and qwen3_moe.py to address these issues by creating an instance-specific copy of the mapping and using a more accurate condition.

Comment on lines +546 to +553
# Only perform the following mapping when Qwen2MoeMLP exists
if getattr(config, "mlp_only_layers", []):
    self.packed_modules_mapping["gate_up_proj"] = (
        [
            "gate_proj",
            "up_proj",
        ],
    )
Contributor


critical

This change introduces two critical issues:

  1. Modification of a class attribute: self.packed_modules_mapping is modified in-place. Since packed_modules_mapping is a class attribute, this modification will affect all other instances of Qwen2MoeForCausalLM, which can lead to unexpected behavior if multiple models with different configurations are used in the same process. An instance-specific copy should be created before modification.

  2. Incorrect condition for MLP existence: The condition if getattr(config, "mlp_only_layers", []) is not sufficient to determine if Qwen2MoeMLP layers (and thus gate_up_proj) exist. For example, a model with decoder_sparse_step > 1 and an empty mlp_only_layers list will have dense MLP layers, but this condition will be false, incorrectly omitting gate_up_proj from the mapping.

A more robust approach is to check if not all layers are sparse MoE layers. This is the case if mlp_only_layers is non-empty, or if there are no experts, or if decoder_sparse_step is not 1. The suggested change below addresses both issues.

        # Create a copy of the mapping to avoid modifying the class attribute.
        self.packed_modules_mapping = self.packed_modules_mapping.copy()
        # Conditionally add gate_up_proj if dense MLP layers exist. A model has
        # dense MLP layers if not all layers are sparse MoE layers.
        if (bool(getattr(config, "mlp_only_layers", [])) or
                getattr(config, "num_experts", 0) == 0 or
                getattr(config, "decoder_sparse_step", 1) != 1):
            self.packed_modules_mapping["gate_up_proj"] = [
                "gate_proj",
                "up_proj",
            ]
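The class-attribute pitfall described above can be reproduced in isolation. The sketch below uses a hypothetical Model class (not vLLM code) to show why mutating a class-level dict in __init__ leaks state into every other instance, and how rebinding to a .copy() avoids it:

```python
class Model:
    # Class attribute shared by every instance, analogous to
    # packed_modules_mapping on the vLLM model classes.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, has_dense_mlp: bool):
        # Rebind to an instance-level copy before mutating. Without this
        # copy, the assignment below would mutate the shared class dict
        # and every instance would see the extra entry.
        self.packed_modules_mapping = self.packed_modules_mapping.copy()
        if has_dense_mlp:
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]

a = Model(has_dense_mlp=True)
b = Model(has_dense_mlp=False)
assert "gate_up_proj" in a.packed_modules_mapping
assert "gate_up_proj" not in b.packed_modules_mapping      # no leak between instances
assert "gate_up_proj" not in Model.packed_modules_mapping  # class dict untouched
```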

Comment on lines +648 to +655
# Only perform the following mapping when Qwen3MoeMLP exists
if getattr(config, "mlp_only_layers", []):
    self.packed_modules_mapping["gate_up_proj"] = (
        [
            "gate_proj",
            "up_proj",
        ],
    )
Contributor


critical

This change introduces two critical issues:

  1. Modification of a class attribute: self.packed_modules_mapping is modified in-place. Since packed_modules_mapping is a class attribute, this modification will affect all other instances of Qwen3MoeForCausalLM, which can lead to unexpected behavior if multiple models with different configurations are used in the same process. An instance-specific copy should be created before modification.

  2. Incorrect condition for MLP existence: The condition if getattr(config, "mlp_only_layers", []) is not sufficient to determine if Qwen3MoeMLP layers (and thus gate_up_proj) exist. For example, a model with decoder_sparse_step > 1 and an empty mlp_only_layers list will have dense MLP layers, but this condition will be false, incorrectly omitting gate_up_proj from the mapping.

A more robust approach is to check if not all layers are sparse MoE layers. This is the case if mlp_only_layers is non-empty, or if there are no experts, or if decoder_sparse_step is not 1. The suggested change below addresses both issues.

Suggested change
# Only perform the following mapping when Qwen3MoeMLP exists
if getattr(config, "mlp_only_layers", []):
    self.packed_modules_mapping["gate_up_proj"] = (
        [
            "gate_proj",
            "up_proj",
        ],
    )

# Create a copy of the mapping to avoid modifying the class attribute.
self.packed_modules_mapping = self.packed_modules_mapping.copy()
# Conditionally add gate_up_proj if dense MLP layers exist. A model has
# dense MLP layers if not all layers are sparse MoE layers.
if (bool(getattr(config, "mlp_only_layers", [])) or
        getattr(config, "num_experts", 0) == 0 or
        getattr(config, "decoder_sparse_step", 1) != 1):
    self.packed_modules_mapping["gate_up_proj"] = [
        "gate_proj",
        "up_proj",
    ]


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Comment on lines +548 to +552
self.packed_modules_mapping["gate_up_proj"] = (
    [
        "gate_proj",
        "up_proj",
    ],


P1: Assign gate_up_proj mapping as list, not tuple

The new conditional adds gate_up_proj using self.packed_modules_mapping["gate_up_proj"] = (["gate_proj", "up_proj"],). Because of the parentheses and trailing comma this stores a tuple whose only element is a list, while the rest of the quantization helpers expect dict[str, list[str]]. When the mapping is consumed (e.g., get_layer_partition_names or LoRA utilities), the tuple is iterated and the list itself is passed to string operations such as removesuffix/replace, raising a TypeError. Any model with mlp_only_layers set will fail during packed-module handling. Assign the list directly without wrapping it in a tuple.

Useful? React with 👍 / 👎.
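The failure mode Codex describes can be shown in a few lines. The consumer loop below is a hypothetical stand-in for the mapping's real downstream users (the actual helpers live in vLLM's quantization and LoRA code):

```python
# The trailing comma inside parentheses makes the value a one-element tuple
# containing a list, not the list[str] that consumers of the mapping expect.
bad = {"gate_up_proj": (["gate_proj", "up_proj"],)}  # tuple[list[str]] -- the bug
good = {"gate_up_proj": ["gate_proj", "up_proj"]}    # list[str] -- as expected

# Downstream helpers iterate the value and use each element as a string,
# e.g. to build full layer names (illustrative consumer, not vLLM code):
resolved = ["model.layers.0.mlp." + name for name in good["gate_up_proj"]]

# With the tuple-wrapped value, iteration yields the whole inner list,
# and the string operation fails:
caught = None
try:
    ["model.layers.0.mlp." + name for name in bad["gate_up_proj"]]
except TypeError as exc:
    caught = exc
assert isinstance(caught, TypeError)
```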

Comment on lines +649 to +653
if getattr(config, "mlp_only_layers", []):
    self.packed_modules_mapping["gate_up_proj"] = (
        [
            "gate_proj",
            "up_proj",


P1: Avoid tuple-wrapping gate_up_proj mapping

Same issue as above: self.packed_modules_mapping["gate_up_proj"] is assigned (["gate_proj", "up_proj"],), producing a tuple instead of the list that downstream quantization and LoRA helpers expect. Iterating this mapping yields the list itself and causes type errors when string concatenation is attempted, so models with mlp_only_layers enabled will crash when retrieving partition names or applying packed transforms. Assign a plain list here.

Useful? React with 👍 / 👎.

Member

@DarkLight1337 DarkLight1337 left a comment


Hmm, packed_modules_mapping is supposed to be a ClassVar. Would editing it after initialization cause problems?

@jeejeelee
Collaborator Author

I checked this previously: packed_modules_mapping is only used after instantiation, so I think it should be safe.

@DarkLight1337
Member

Can you update the interface definition then?

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@jeejeelee
Collaborator Author

Done in bbf7ef7

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 11, 2025 12:44
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 11, 2025
Comment on lines +546 to +552
# Only perform the following mapping when Qwen2MoeMLP exists
if getattr(config, "mlp_only_layers", []):
    self.packed_modules_mapping["gate_up_proj"] = (
        [
            "gate_proj",
            "up_proj",
        ],
Member


I think this condition doesn't really fit Qwen2MoE's case, because Qwen2MoE also has a shared expert that needs packing inside the sparse MoE block:

if config.shared_expert_intermediate_size > 0:
    self.shared_expert = Qwen2MoeMLP(
        hidden_size=config.hidden_size,
        intermediate_size=config.shared_expert_intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        reduce_results=False,
        expert_gate=self.shared_expert_gate,
        prefix=f"{prefix}.shared_expert",
    )
else:
    self.shared_expert = None
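One hedged reading of the combined review feedback: gate_up_proj packing is needed whenever any Qwen2MoeMLP is constructed, which happens both for dense layers and for the shared expert inside sparse MoE blocks. The helper below is hypothetical (not the merged vLLM code); the config field names mirror the HF Qwen2MoE config referenced in the quoted snippet:

```python
from types import SimpleNamespace

def needs_gate_up_proj(config) -> bool:
    # Dense MLP layers exist if not all layers are sparse MoE layers.
    has_dense_mlp = (
        bool(getattr(config, "mlp_only_layers", []))
        or getattr(config, "num_experts", 0) == 0
        or getattr(config, "decoder_sparse_step", 1) != 1
    )
    # A shared expert is also a Qwen2MoeMLP, even when all layers are sparse.
    has_shared_expert = getattr(config, "shared_expert_intermediate_size", 0) > 0
    return has_dense_mlp or has_shared_expert

# All layers sparse, but a shared expert exists -> packing is still needed.
cfg = SimpleNamespace(mlp_only_layers=[], num_experts=60,
                      decoder_sparse_step=1,
                      shared_expert_intermediate_size=5632)
assert needs_gate_up_proj(cfg)
```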

Collaborator Author


Good catch, fixed in 2b0ae9a

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@Isotr0py Isotr0py enabled auto-merge (squash) October 11, 2025 13:34
@Isotr0py Isotr0py merged commit f0a30a0 into vllm-project:main Oct 11, 2025
54 checks passed
@jeejeelee jeejeelee deleted the fix-qwen-moe-mapping branch October 11, 2025 16:06
@hmellor
Member

hmellor commented Oct 12, 2025

The Codex review was correct. You have assigned a tuple[list[str]] to the value in packed_modules_mapping when it should be just list[str].

This PR causes pytest tests/evals/gsm8k/test_gsm8k_correctness.py::test_gsm8k_correctness_param[Qwen1.5-MoE-W4A16-CT-tp1] to be unrunnable.

@hmellor
Member

hmellor commented Oct 12, 2025

I have fixed it in #26633, but I still can't seem to run that eval test. I get:

[core.py:792] EngineCore encountered a fatal error.
[core.py:792] Traceback (most recent call last):
[core.py:792]   File "/home/harry/vllm/vllm/v1/engine/core.py", line 783, in run_engine_core
[core.py:792]     engine_core.run_busy_loop()
[core.py:792]   File "/home/harry/vllm/vllm/v1/engine/core.py", line 810, in run_busy_loop
[core.py:792]     self._process_engine_step()
[core.py:792]   File "/home/harry/vllm/vllm/v1/engine/core.py", line 839, in _process_engine_step
[core.py:792]     outputs, model_executed = self.step_fn()
[core.py:792]                               ^^^^^^^^^^^^^^
[core.py:792]   File "/home/harry/vllm/vllm/v1/engine/core.py", line 320, in step
[core.py:792]     scheduler_output = self.scheduler.schedule()
[core.py:792]                        ^^^^^^^^^^^^^^^^^^^^^^^^^
[core.py:792]   File "/home/harry/vllm/vllm/v1/core/sched/scheduler.py", line 256, in schedule
[core.py:792]     new_blocks = self.kv_cache_manager.allocate_slots(
[core.py:792]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[core.py:792]   File "/home/harry/vllm/vllm/v1/core/kv_cache_manager.py", line 317, in allocate_slots
[core.py:792]     self.coordinator.cache_blocks(request, num_tokens_to_cache)
[core.py:792]   File "/home/harry/vllm/vllm/v1/core/kv_cache_coordinator.py", line 138, in cache_blocks
[core.py:792]     manager.cache_blocks(request, num_computed_tokens)
[core.py:792]   File "/home/harry/vllm/vllm/v1/core/single_type_kv_cache_manager.py", line 156, in cache_blocks
[core.py:792]     self.block_pool.cache_full_blocks(
[core.py:792]   File "/home/harry/vllm/vllm/v1/core/block_pool.py", line 232, in cache_full_blocks
[core.py:792]     assert blk.block_hash is None
[core.py:792]            ^^^^^^^^^^^^^^^^^^^^^^
[core.py:792] AssertionError
[async_llm.py:518] AsyncLLM output_handler failed.
[async_llm.py:518] Traceback (most recent call last):
[async_llm.py:518]   File "/home/harry/vllm/vllm/v1/engine/async_llm.py", line 472, in output_handler
[async_llm.py:518]     outputs = await engine_core.get_output_async()
[async_llm.py:518]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[async_llm.py:518]   File "/home/harry/vllm/vllm/v1/engine/core_client.py", line 882, in get_output_async
[async_llm.py:518]     raise self._format_exception(outputs) from None
[async_llm.py:518] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.

But can't reproduce it outside of this test.
