[Bug Fix] MTP Speculative Decoding with NVFP4: Weight Shape Mismatch #35041
jagguvarma15 wants to merge 5 commits into vllm-project:main
Conversation
Replace nn.Linear with ReplicatedLinear for eh_proj in MTP models (vllm-project#35031) to fix a weight shape mismatch when using NVFP4 quantization.

The eh_proj layer in multiple MTP model files was defined as a plain nn.Linear, which does not participate in vLLM's quantization framework. When NVFP4 quantized checkpoints are loaded, the weights have a different shape (packed uint8 with halved input dimension), causing a shape mismatch error. ReplicatedLinear (a LinearBase subclass) properly integrates with the quantization system, accepting a quant_config parameter so it correctly handles quantized weight formats, including NVFP4.

Files changed:
- deepseek_mtp.py
- step3p5_mtp.py
- openpangu_mtp.py
- glm_ocr_mtp.py
- glm4_moe_mtp.py
- glm4_moe_lite_mtp.py

Signed-off-by: Jagadeesh Varma <jagguvarma15@gmail.com>
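The halved input dimension mentioned in the commit message can be sketched with a quick shape calculation. This is an illustration of FP4 byte packing in general, not vLLM code: two 4-bit values fit in one uint8, so a quantized checkpoint stores the weight with half the input dimension that an unquantized nn.Linear would expect (hidden_size of 4096 is an assumed example value).

```python
# Hypothetical illustration of why the checkpoint shape mismatches:
# nn.Linear(hidden_size * 2, hidden_size) expects a float weight of shape
# (out_features, in_features). An NVFP4 checkpoint packs two 4-bit values
# into each uint8 byte, halving the stored input dimension.
hidden_size = 4096  # assumed example value

# Shape the unquantized layer expects (e.g. bf16 weights).
unquantized_shape = (hidden_size, hidden_size * 2)

# Shape the NVFP4 checkpoint actually provides (packed uint8 weights).
nvfp4_packed_shape = (hidden_size, (hidden_size * 2) // 2)

print(unquantized_shape)   # (4096, 8192)
print(nvfp4_packed_shape)  # (4096, 4096)
```

Loading the second shape into a layer expecting the first is the reported error; a quantization-aware layer knows to allocate the packed shape instead.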
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small, essential subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run the full CI to test the changes comprehensively before merging.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request correctly addresses the weight shape mismatch issue in MTP speculative decoding when using NVFP4 quantization. By replacing nn.Linear with ReplicatedLinear for the eh_proj layer across multiple MTP model files, the code now properly integrates with vLLM's quantization framework. The updates to the forward methods to handle the tuple return of ReplicatedLinear are also correctly implemented. The changes are consistent and follow established patterns in the repository.
Add comprehensive performance analysis for MiniMax-M2.5-REAP-139B-A10B-NVFP4-GB10.

Architecture confirmed:
- Attention IS NVFP4 in this model (the ignore list contains only lm_head and the MoE gates)
- 3 MTP modules present (layers 62-64): the biggest performance lever available
- Per-step weight load: ~6.15 GB → 36–44 tok/s theoretical ceiling on GB10

Performance gap analysis:
- Current: 24 tok/s on Strix Halo (AMD); GB10 expected to have a similar baseline
- vLLM is 1.78x slower than SGLang at BS=1 for NVFP4 MoE (documented gap)
- Gap sources: activation quant overhead, kernel launch overhead, no fused shuffle+reduce in MoE, generic CUTLASS configs

Key new PRs to integrate:
- vllm-project#35041 (OPEN): MTP+NVFP4 weight shape mismatch; required for MTP+NVFP4
- vllm-project#35442 (OPEN): non-blocking MTP token copy; 6 ms → 200 µs CPU-GPU sync
- vllm-project#33303 (OPEN): MiniMax PP+DP for multi-Spark scaling

Already-merged PRs confirmed in HEAD:
- vllm-project#34718 (act_quant_fusion.py): SiLU+FP4 fusion
- vllm-project#34899 (allreduce_rms_fusion.py): NVFP4 AR+Norm fusion
- vllm-project#30885: 8x4 SF tiling (not yet effective on GB10; TRTLLM backend blocked)
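The "~6.15 GB per step → 36–44 tok/s ceiling" claim above is a simple bandwidth-bound estimate that can be checked with back-of-envelope arithmetic: at batch size 1, every decode step must stream the active weights from memory, so tokens/s is bounded by bandwidth divided by bytes read per step. The 222 and 273 GB/s figures below are assumed effective and peak memory bandwidth values chosen to reproduce the stated range; they are not taken from the commit message.

```python
# Back-of-envelope check of the stated 36–44 tok/s theoretical ceiling.
# At BS=1 each decode step reads the full active weight set, so:
#     tokens/s <= memory_bandwidth_GB_per_s / weights_GB_per_step
per_step_weights_gb = 6.15            # per-step weight load from the analysis
bandwidth_range_gbps = (222, 273)     # assumed effective / peak bandwidth

ceiling = tuple(int(bw / per_step_weights_gb) for bw in bandwidth_range_gbps)
print(ceiling)  # (36, 44)
```

This is why the 3 MTP modules are called the biggest lever: each accepted speculative token amortizes one weight pass over multiple output tokens, raising the effective ceiling.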
Fix: MTP Speculative Decoding with NVFP4 Weight Shape Mismatch
Fixes #35031
Summary
The `eh_proj` layer in multiple MTP (Multi-Token Prediction) model files was defined as a plain `nn.Linear`, which does not participate in vLLM's quantization framework. When NVFP4 quantized checkpoints are loaded, the weights have a different shape (packed `uint8` with a halved input dimension for FP4 packing), causing a weight shape mismatch error.

Root Cause

`nn.Linear` is not recognized by `ModelOptQuantConfigBase.get_quant_method()`, which only handles `LinearBase` subclasses. This means `eh_proj` never gets a quantization method assigned, so it expects unquantized weight shapes, but the checkpoint provides NVFP4-quantized weights with different dimensions.

Fix
Replaced `nn.Linear` with `ReplicatedLinear` (a `LinearBase` subclass) for the `eh_proj` layer across 6 MTP model files. `ReplicatedLinear` properly integrates with vLLM's quantization system, accepting a `quant_config` parameter so it correctly handles quantized weight formats, including NVFP4.

The existing `longcat_flash_mtp.py` already used `ReplicatedLinear` correctly and served as the reference for this fix.

Files Changed

- `vllm/model_executor/models/deepseek_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`
- `vllm/model_executor/models/step3p5_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`
- `vllm/model_executor/models/openpangu_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`
- `vllm/model_executor/models/glm_ocr_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`
- `vllm/model_executor/models/glm4_moe_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`
- `vllm/model_executor/models/glm4_moe_lite_mtp.py`: `nn.Linear` → `ReplicatedLinear` for `eh_proj`

Changes per file

- Added the `ReplicatedLinear` import from `vllm.model_executor.layers.linear`
- Replaced `nn.Linear(hidden_size * 2, hidden_size, bias=False)` with `ReplicatedLinear(..., quant_config=quant_config, prefix=...)`
- Updated `forward()` to unpack the tuple return: `hidden_states, _ = self.eh_proj(...)`

Assisted by Claude Opus 4.6
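The `forward()` change listed above can be sketched with a minimal stand-in. `FakeReplicatedLinear` below is a hypothetical stub, not vLLM code; it only mimics the convention the PR relies on, namely that `LinearBase` subclasses such as `ReplicatedLinear` return an `(output, output_bias)` tuple rather than a bare tensor.

```python
# Minimal sketch (hypothetical stand-in, not vLLM code) of the call-site change.
class FakeReplicatedLinear:
    """Mimics ReplicatedLinear's (output, output_bias) return convention."""

    def __call__(self, x):
        # Identity "projection" for illustration; bias slot is None here.
        return x, None

eh_proj = FakeReplicatedLinear()

# Before (plain nn.Linear):    hidden_states = self.eh_proj(x)
# After (ReplicatedLinear):    unpack the tuple, discarding the bias slot.
hidden_states, _ = eh_proj([1.0, 2.0])
print(hidden_states)  # [1.0, 2.0]
```

Forgetting this unpack is the classic failure mode when swapping `nn.Linear` for a `LinearBase` subclass: downstream ops would receive a tuple instead of a tensor.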