
[Bug Fix] MTP Speculative Decoding with NVFP4: Weight Shape Mismatch #35041

Open
jagguvarma15 wants to merge 5 commits into vllm-project:main from jagguvarma15:fix/mtp-nvfp4-weight-shape-mismatch

Conversation

@jagguvarma15

Fix: MTP Speculative Decoding with NVFP4 Weight Shape Mismatch

Fixes #35031

Summary

The eh_proj layer in multiple MTP (Multi-Token Prediction) model files was defined as a plain nn.Linear, which does not participate in vLLM's quantization framework. When NVFP4 quantized checkpoints are loaded, the weights have a different shape (packed uint8 with halved input dimension for FP4 packing), causing a weight shape mismatch error.
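As an illustration of the mismatch, NVFP4 packs two 4-bit values per uint8 byte, so the checkpoint's eh_proj weight has half the input dimension that a plain nn.Linear would allocate. A minimal sketch (the hidden_size value here is hypothetical, not taken from any specific model):

```python
# Sketch of the NVFP4 shape mismatch: two FP4 weights are packed per uint8 byte,
# halving the input dimension of the stored weight tensor.
hidden_size = 4096
in_features = 2 * hidden_size            # eh_proj consumes a 2x-hidden concatenation
out_features = hidden_size

# Shape a plain nn.Linear allocates and therefore expects from the checkpoint:
unquantized_shape = (out_features, in_features)
# Shape the NVFP4 checkpoint actually provides (packed uint8):
nvfp4_packed_shape = (out_features, in_features // 2)

print(unquantized_shape)   # (4096, 8192)
print(nvfp4_packed_shape)  # (4096, 4096) -> loader raises a shape mismatch
```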

Root Cause

nn.Linear is not recognized by ModelOptQuantConfigBase.get_quant_method(), which only handles LinearBase subclasses. This means eh_proj never gets a quantization method assigned, so it expects unquantized weight shapes — but the checkpoint provides NVFP4-quantized weights with different dimensions.
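The dispatch failure can be sketched with simplified stand-ins (these classes and the returned string are illustrative placeholders, not vLLM's actual implementation):

```python
# Stand-ins illustrating why a plain nn.Linear is skipped by quantization dispatch.

class LinearBase:                     # stand-in for vLLM's LinearBase
    pass

class ReplicatedLinear(LinearBase):   # LinearBase subclass -> gets a quant method
    pass

class PlainLinear:                    # stand-in for torch.nn.Linear (no LinearBase)
    pass

def get_quant_method(layer):
    """Mimics the dispatch shape of ModelOptQuantConfigBase.get_quant_method:
    only LinearBase subclasses receive a quantization method."""
    if isinstance(layer, LinearBase):
        return "nvfp4_linear_method"  # placeholder for a real quant method object
    return None                       # plain layers expect unquantized weights

print(get_quant_method(ReplicatedLinear()))  # 'nvfp4_linear_method'
print(get_quant_method(PlainLinear()))       # None
```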

Fix

Replaced nn.Linear with ReplicatedLinear (a LinearBase subclass) for the eh_proj layer across 6 MTP model files. ReplicatedLinear properly integrates with vLLM's quantization system, accepting a quant_config parameter so it correctly handles quantized weight formats including NVFP4.

The existing longcat_flash_mtp.py already used ReplicatedLinear correctly and served as the reference for this fix.

Files Changed

| File | Change |
| --- | --- |
| vllm/model_executor/models/deepseek_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |
| vllm/model_executor/models/step3p5_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |
| vllm/model_executor/models/openpangu_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |
| vllm/model_executor/models/glm_ocr_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |
| vllm/model_executor/models/glm4_moe_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |
| vllm/model_executor/models/glm4_moe_lite_mtp.py | nn.Linear → ReplicatedLinear for eh_proj |

Changes per file

  1. Added ReplicatedLinear import from vllm.model_executor.layers.linear
  2. Replaced nn.Linear(hidden_size * 2, hidden_size, bias=False) with ReplicatedLinear(..., quant_config=quant_config, prefix=...)
  3. Updated forward() to unpack the tuple return: hidden_states, _ = self.eh_proj(...)
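Step 3 is needed because vLLM's LinearBase layers return an (output, bias) tuple rather than a bare tensor. A toy stand-in of the calling-convention difference (hypothetical classes, no torch, identity in place of the matmul):

```python
# Toy stand-ins for the forward() return-type difference (not vLLM's real classes).

class PlainLinearStandIn:
    """Like torch.nn.Linear: __call__ returns just the output."""
    def __call__(self, x):
        return x  # identity in place of a real matmul

class ReplicatedLinearStandIn:
    """Like vLLM's ReplicatedLinear: __call__ returns (output, output_bias)."""
    def __call__(self, x):
        return x, None

x = [1.0, 2.0]
out_old = PlainLinearStandIn()(x)          # tensor used directly
out_new, _ = ReplicatedLinearStandIn()(x)  # tuple must be unpacked in forward()
print(out_old == out_new)                  # True
```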

Assisted by Claude Opus 4.6

…llm-project#35031)

Replace nn.Linear with ReplicatedLinear for eh_proj in MTP models to
fix weight shape mismatch when using NVFP4 quantization.

The eh_proj layer in multiple MTP model files was defined as a plain
nn.Linear, which does not participate in vLLM's quantization framework.
When NVFP4 quantized checkpoints are loaded, the weights have a
different shape (packed uint8 with halved input dimension), causing a
shape mismatch error.

ReplicatedLinear (a LinearBase subclass) properly integrates with the
quantization system, accepting a quant_config parameter so it correctly
handles quantized weight formats including NVFP4.

Files changed:
- deepseek_mtp.py
- step3p5_mtp.py
- openpangu_mtp.py
- glm_ocr_mtp.py
- glm4_moe_mtp.py
- glm4_moe_lite_mtp.py

Signed-off-by: Jagadeesh Varma <jagguvarma15@gmail.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests that quickly catches errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@dosubot

dosubot bot commented Feb 22, 2026

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


@mergify mergify bot added the "deepseek" (Related to DeepSeek models) and "bug" (Something isn't working) labels Feb 22, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses the weight shape mismatch issue in MTP speculative decoding when using NVFP4 quantization. By replacing nn.Linear with ReplicatedLinear for the eh_proj layer across multiple MTP model files, the code now properly integrates with vLLM's quantization framework. The updates to the forward methods to handle the tuple return of ReplicatedLinear are also correctly implemented. The changes are consistent and follow established patterns in the repository.

scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 2, 2026
Add comprehensive performance analysis for MiniMax-M2.5-REAP-139B-A10B-NVFP4-GB10:

Architecture confirmed:
- Attention IS NVFP4 in this model (ignore list = only lm_head + MoE gates)
- 3 MTP modules present (layers 62-64) — biggest performance lever available
- Per-step weight load: ~6.15 GB → 36–44 tok/s theoretical ceiling on GB10

Performance gap analysis:
- Current: 24 tok/s on Strix Halo (AMD); GB10 expected similar baseline
- vLLM is 1.78x slower than SGLang at BS=1 for NVFP4 MoE (documented gap)
- Gap sources: activation quant overhead, kernel launch overhead, no fused
  shuffle+reduce in MoE, generic CUTLASS configs

Key new PRs to integrate:
- vllm-project#35041 (OPEN): MTP+NVFP4 weight shape mismatch — required for MTP+NVFP4
- vllm-project#35442 (OPEN): Non-blocking MTP token copy — 6ms→200µs CPU-GPU sync
- vllm-project#33303 (OPEN): MiniMax PP+DP for multi-Spark scaling

Already-merged PRs confirmed in HEAD:
- vllm-project#34718 (act_quant_fusion.py): SiLU+FP4 fusion
- vllm-project#34899 (allreduce_rms_fusion.py): NVFP4 AR+Norm fusion
- vllm-project#30885: 8x4 SF tiling (not yet effective on GB10 — TRTLLM backend blocked)
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026