[BUGFIX] Fix Pixtral consolidated format vision weight loading #39916
Conversation
Code Review
This pull request adds support for Mistral native (consolidated) weight formats to the Pixtral model and introduces a new test case for consolidated loading. Feedback indicates that the added test uses a text-only model, which fails to exercise the vision encoder weight loading logic. Additionally, the weight loading implementation may fail to match keys containing a '.weight' suffix, and the remapping logic for native parameter names is inefficiently located and may lead to dropped weights.
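One of the points raised, the '.weight' suffix issue, is easy to illustrate. Below is a minimal, standalone sketch (not the PR's actual matching code) showing why suffix-based matching on the bare layer name misses real checkpoint keys, which always carry a trailing ".weight":

```python
# Illustrative sketch of the '.weight' suffix pitfall mentioned in the review;
# the key string is a made-up example, not taken from the PR.
key = "vision_encoder.layers.0.attention.wq.weight"

# Fails: the checkpoint key ends with ".weight", not ".wq".
matches_endswith = key.endswith(".wq")   # False

# Works: a substring match tolerates the trailing ".weight".
matches_substring = ".wq." in key        # True

print(matches_endswith, matches_substring)
```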
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Signed-off-by: juliendenize <julien.denize@mistral.ai>
Force-pushed from e1ed624 to b22a56b
| (".qkv_proj", ".v_proj", "v"), | ||
| (".gate_up_proj", ".gate_proj", 0), | ||
| (".gate_up_proj", ".up_proj", 1), | ||
| # Mistral native (consolidated) format |
The wo and w2 parameters are handled via _vision_encoder_name_remap rather than through _vision_encoder_stacked_params. Since they're not sharded across TP ranks like qkv/w1/w3, they don't appear in the stacked params list. Is there a reason they couldn't be added to the stacked params list with their shard_id, or is the remap approach more robust to variations in how these keys appear across different checkpoint formats?
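For context, a minimal sketch of the plain-rename alternative being discussed is shown below. The helper name `_vision_encoder_name_remap` comes from the comment above, but its contents and the wrapper function are illustrative assumptions, not the PR's literal code:

```python
# Illustrative remap for the non-sharded wo/w2 weights (assumed contents,
# not the PR's exact mapping).
_vision_encoder_name_remap = {
    ".wo.": ".o_proj.",     # attention output projection
    ".w2.": ".down_proj.",  # MLP down projection
}

def remap_native_name(name: str) -> str:
    """Rewrite a Mistral-native checkpoint key to the HF-style module name."""
    for native, hf in _vision_encoder_name_remap.items():
        if native in name:
            return name.replace(native, hf)
    return name

# Example: a consolidated-format key maps onto the merged module's name.
print(remap_native_name("vision_encoder.layers.0.attention.wo.weight"))
# -> vision_encoder.layers.0.attention.o_proj.weight
```

Because these weights are not split across TP ranks, a simple rename followed by the default weight loader is sufficient, whereas the stacked-params path exists to route shards of merged layers.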
Hey, thanks for the merge! To answer your question @NickLucche: it does, by making sure the output is not garbled when running on a smaller GPU than Pixtral needs. I think I could lower the GPU size even further, but AFAIK the CI should always use 16 GB, right? So now we should catch it whenever a regression happens!
Purpose
#36963 replaced the Pixtral vision encoder's nn.Linear layers (wq/wk/wv/wo/w1/w2/w3) with QKVParallelLinear and MergedColumnParallelLinear (qkv_proj/o_proj/gate_up_proj/down_proj) to support LoRA. However, the weight-loading stacked_params mapping only covered HF-style names (q_proj, k_proj, etc.), not the Mistral native names (wq, wk, etc.), so vision encoder weights were silently dropped when loading consolidated-format checkpoints.
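A minimal sketch of what a stacked-params mapping covering both name families could look like; the tuple layout (merged name, checkpoint name, shard id) follows vLLM's usual convention, but the exact entries below are illustrative rather than the PR's literal code:

```python
# Sketch: route both HF-style and Mistral-native checkpoint names into the
# merged qkv_proj / gate_up_proj parameters (illustrative, not the PR's code).
stacked_params_mapping = [
    # HF-style names
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
    # Mistral native (consolidated) names
    (".qkv_proj", ".wq", "q"),
    (".qkv_proj", ".wk", "k"),
    (".qkv_proj", ".wv", "v"),
    (".gate_up_proj", ".w1", 0),
    (".gate_up_proj", ".w3", 1),
]
```

Without the second group, consolidated-format keys never match an entry, fall through the loader, and are dropped instead of being loaded into the merged modules.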
Test Plan
Added a Ministral test that runs on small GPUs instead of relying on Pixtral.
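For illustration, a hedged sketch of what such a consolidated-format smoke test could look like; the model id, prompt, and assertion are placeholders, not the PR's actual test:

```python
# Sketch of a consolidated-format loading smoke test (placeholder model id
# and prompt; the mistral tokenizer/config/load modes select native checkpoints).
from vllm import LLM, SamplingParams


def test_consolidated_format_loading():
    llm = LLM(
        model="mistralai/Ministral-8B-Instruct-2410",  # placeholder model id
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
        max_model_len=2048,
    )
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    # Loading succeeded and generation produced non-empty text.
    assert outputs and outputs[0].outputs[0].text
```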
Test Result
Passing.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.