[BUGFIX] Fix accuracy regression for NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 with TP>1 #34476
Conversation
Code Review
This pull request addresses an accuracy regression for quantized Mamba models with tensor parallelism greater than one. The root cause was a previous change that incorrectly unified weight loading logic, leading to a misalignment between weights and their quantization scales. This fix correctly reverts the problematic part of that change by restoring the use of MergedColumnParallelLinear when the number of groups is divisible by the tensor-parallel size. This is the right approach, as MergedColumnParallelLinear correctly handles the sharding of all parameters, including quantization scales, resolving the issue. The special handling for n_groups=1, which was already correct, is preserved. The change is well-justified, clearly explained, and effectively fixes the bug.
…VFP4 with TP>1 (vllm-project#34476) Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: athrael-soju <athrael-soju@users.noreply.github.com>
Purpose
Fix accuracy regression for NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 with TP>1.
Alternative to #34151.
What happened
PR #33257 (a372f3f40) was a squash-merge of two logical changes:

1. Quantized TP support for `n_groups == 1` (e.g., Falcon-H1R-7B with FP8) — correct and needed.
2. Unifying MambaMixer2 TP sharding (both `n_groups % tp_size == 0` and `n_groups == 1`) to always use `ColumnParallelLinear` with a custom `mamba_v2_sharded_weight_loader` — incorrect for quantized models.

The second change (d5d6d0b88) replaced `MergedColumnParallelLinear` with `ColumnParallelLinear` + custom weight loader for ALL cases, but only overrode the weight loader on `in_proj.weight`. For quantized models (NVFP4, FP8), there are also scale parameters (`weight_scale`, `input_scale`, `weight_scale_2`) that still use `ColumnParallelLinear`'s default contiguous sharding.

The mamba `in_proj` has a composite weight layout `[gate, intermediate, B_groups, C_groups, dt_heads]`. The custom mamba loader shards each component separately, but the scale parameters get simple contiguous sharding, causing weight-scale misalignment: the dequantization scale at row N corresponds to a different model component than the weight at row N. `MergedColumnParallelLinear` handles this correctly because it knows the per-component output sizes and shards ALL parameters (weights and scales) accordingly.
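The misalignment can be shown with a toy sketch (not vLLM code; the component sizes below are made up for illustration) that shards a composite layout per component for the weight but contiguously for its scales:

```python
# Toy illustration (not vLLM code; component sizes are made up) of why
# sharding in_proj.weight per component while its quantization scales
# fall back to contiguous sharding misaligns rows under TP.
sizes = {"gate": 8, "intermediate": 8, "B_groups": 4, "C_groups": 4, "dt_heads": 2}
tp_size = 2
total = sum(sizes.values())
# Label every output row of the composite layout with its component.
labels = [name for name, n in sizes.items() for _ in range(n)]

def component_shard(rank):
    """Per-component sharding: each rank takes its slice of EVERY component
    (what the custom mamba loader does for the weight)."""
    rows, offset = [], 0
    for n in sizes.values():
        part = n // tp_size
        rows.extend(range(offset + rank * part, offset + (rank + 1) * part))
        offset += n
    return rows

def contiguous_shard(rank):
    """Default contiguous sharding: each rank takes one contiguous block
    (what the scale parameters got)."""
    part = total // tp_size
    return list(range(rank * part, (rank + 1) * part))

# On rank 1, count local rows whose weight and scale come from
# different components.
w_rows, s_rows = component_shard(1), contiguous_shard(1)
mismatched = sum(labels[w] != labels[s] for w, s in zip(w_rows, s_rows))
print(f"rank 1: {mismatched}/{len(w_rows)} rows pair a weight with a scale "
      f"from a different component")
```

With these toy sizes, most local rows on rank 1 end up dequantized with a scale belonging to a different component, which is exactly the kind of silent accuracy corruption seen in the regression.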
Fix

Revert d5d6d0b88 ("Unify MambaMixer2 TP sharding to use custom weight loader"), the last commit from #33257. This restores `MergedColumnParallelLinear` for the `n_groups % tp_size == 0` case while preserving the `n_groups == 1` quantized TP support.
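The invariant this restores can be sketched as a hedged toy model (not the actual vLLM class, and the output sizes are illustrative): when every parameter is sharded with the same per-component splits, each weight row stays paired with its own scale and dequantization recovers the original rows.

```python
import random

# Hedged toy sketch (not the actual vLLM class): the key property of a
# merged column-parallel layer is that it knows the per-component output
# sizes and applies the SAME per-component sharding to every parameter,
# quantization scales included.
output_sizes = [8, 8, 4, 4, 2]  # illustrative [gate, intermediate, B, C, dt]
tp_size = 2

def shard(rows, rank):
    """Shard a per-row parameter (weight rows or their scales) per
    component along the output dimension."""
    out, offset = [], 0
    for n in output_sizes:
        part = n // tp_size
        out.extend(rows[offset + rank * part: offset + (rank + 1) * part])
        offset += n
    return out

random.seed(0)
weight = [random.uniform(-1.0, 1.0) for _ in range(sum(output_sizes))]
scale = [abs(w) / 7.0 + 1e-6 for w in weight]            # toy per-row scale
qweight = [round(w / s) for w, s in zip(weight, scale)]  # toy quantization

for rank in range(tp_size):
    deq = [q * s for q, s in zip(shard(qweight, rank), shard(scale, rank))]
    ref = shard(weight, rank)
    # Identical sharding keeps every weight row paired with its own scale,
    # so dequantization recovers each row up to quantization error.
    assert all(abs(d - r) <= s for d, r, s in zip(deq, ref, shard(scale, rank)))
```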
Test Results

Model: `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4`, TP=8, gsm8k 5-shot (limit=500), 8×B200.

After fix (3 runs): 3/3 runs correct.