[Bug Fix] Qwen3.5-nvfp4 MTP Speculative Decoding Weight Shape Mismatch#35675
nguyen599 wants to merge 1 commit into vllm-project:main from
Conversation
Signed-off-by: nguyen599 <nguyenmanh599123@gmail.com>
Code Review
This pull request addresses a crash during weight loading for Qwen3.5-nvfp4 models when using MTP speculative decoding. The root cause appears to be a shape mismatch for the fc layer's weights in the Qwen3_5MultiTokenPredictor. The fix changes this layer from a ColumnParallelLinear to a ReplicatedLinear, which avoids tensor-parallel sharding for this specific weight and aligns with the expected weight format. Consequently, quantization is disabled for this layer, and the forward pass is updated to correctly handle the layer's output. The changes are logical, well-contained, and directly resolve the issue described.
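The shape arithmetic behind the mismatch can be sketched without vLLM. This is a simplified illustration with illustrative names and dimensions, not vLLM's actual loader: a column-parallel layer allocates only a per-rank slice of the output dimension, so the checkpoint's full fc weight cannot be copied into it, while a replicated layer keeps the full shape on every rank.

```python
# Simplified sketch (names and sizes are illustrative, not vLLM's real ones):
# why a full checkpoint weight fails to load into a TP-sharded parameter.
HIDDEN, TP_SIZE = 256, 2
OUT_FEATURES, IN_FEATURES = HIDDEN, 2 * HIDDEN  # fc: concat(embed, hidden) -> hidden

ckpt_shape = (OUT_FEATURES, IN_FEATURES)  # shape stored in the checkpoint

def column_parallel_shape(out_f, in_f, tp):
    # ColumnParallelLinear shards the output dim: each rank holds 1/tp of it
    return (out_f // tp, in_f)

def replicated_shape(out_f, in_f, tp):
    # ReplicatedLinear keeps the full, unsharded weight on every rank
    return (out_f, in_f)

print(column_parallel_shape(OUT_FEATURES, IN_FEATURES, TP_SIZE))  # (128, 512) != ckpt_shape
print(replicated_shape(OUT_FEATURES, IN_FEATURES, TP_SIZE))       # (256, 512) == ckpt_shape
```

With tp > 1, the sharded shape never matches the checkpoint's, which is consistent with switching the fc layer to ReplicatedLinear.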
Cherry-pick upstream fixes for GB10 Spark (SM121):
- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell: adds on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
@nguyen599 Isn't this already fixed in the latest branch? Can you double-check?
@voipmonitor I just checked vLLM at commit 6f0dd93 and it still shows the error when reproducing. I used 1×H100 and the model txn545/Qwen3.5-35B-A3B-NVFP4.
PR vllm-project#35675 equivalent (MTP fc layer fix), in qwen3_5_mtp.py:
- Switched the import from ColumnParallelLinear to ReplicatedLinear
- Changed the FC construction from self.fc = ColumnParallelLinear(...) to self.fc = ReplicatedLinear(...)
- Removed TP-only args (gather_output, return_bias)
- Set quant_config=None for this layer
- Updated the call site to unpack the tuple: hidden_states, _ = self.fc(hidden_states)

PR vllm-project#35936 equivalent (tool_choice="required" fallback), in engine/serving.py:
- Replaced the JSON-parse suppress block at elif request.tool_choice == "required":
- New flow: first try TypeAdapter(...).validate_json(content); on ValidationError or JSON decode error, fall back to the configured tool parser when available
- Convert parsed tool calls into FunctionCall(...) entries
- Removed the now-unused contextlib import

Signed-off-by: ec-jt <james.trappett@elementalcompute.com>
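The tool_choice="required" fallback flow described above can be sketched in plain Python. This is a minimal stand-in, not vLLM's implementation: json.loads replaces the pydantic TypeAdapter validation mentioned in the commit message, and tool_parser is a hypothetical callable standing in for the configured tool parser.

```python
import json

def parse_required_tool_calls(content, tool_parser=None):
    """Sketch: strict JSON first, then fall back to the model's tool parser."""
    try:
        # Strict path: the model emitted a JSON list of tool calls.
        calls = json.loads(content)
        if isinstance(calls, list):
            return calls
    except json.JSONDecodeError:
        pass
    # Fallback path: non-JSON (e.g. XML-style) output from models like Qwen3.
    if tool_parser is not None:
        return tool_parser(content)
    raise ValueError("tool_choice='required' but output is not parseable")

# Hypothetical stand-in parser for an XML-style tool call
xml_out = "<tool_call>get_weather</tool_call>"
print(parse_required_tool_calls(xml_out, tool_parser=lambda s: [{"name": "get_weather"}]))
# → [{'name': 'get_weather'}]
```

The design point is that strict validation stays the default; the parser is only consulted once validation fails, so well-formed JSON tool calls are unaffected.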
With Qwen3.5-nvfp4, when launching with --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}', the engine crashes during drafter weight loading. This PR addresses the weight shape mismatch issue in MTP speculative decoding.
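A reproduction along the lines described in the thread might look like the following. The speculative-config value is quoted from this description and the model path from the comment above; treat the exact invocation as an assumption, not a verified command.

```shell
# Hypothetical reproduction sketch: serve the NVFP4 checkpoint with MTP
# speculative decoding enabled (flag value taken from the PR description).
vllm serve txn545/Qwen3.5-35B-A3B-NVFP4 \
  --speculative-config '{"method": "qwen3_next_mtp", "num_speculative_tokens": 2}'
```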
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.