[BugFix] KeyError when loading Mistral/vision-enabled checkpoints#33008
ms1design wants to merge 1 commit into vllm-project:main from
Conversation
Code Review
This pull request effectively resolves a KeyError that occurred when loading certain Mistral vision-enabled checkpoints. The root cause was correctly identified as unsafe dictionary access in pixtral.py. The fix, which involves replacing direct dictionary access with the safer .get() method and adding None checks, is appropriate and well-implemented for both patch_merger and pre_mm_projector_norm weights. Additionally, the logic for handling different weight prefixes has been correctly updated. The changes are robust and directly address the reported bug without introducing any new issues. The code is now more resilient to variations in checkpoint weight names.
Signed-off-by: Mieszko Syty <mieszko@ms1design.pl>
Force-pushed from c4e570a to bbd3a2a
Resolved in recent commits.
I can indeed reproduce @dbari's issue for ML3. @ms1design, can you reopen your PR? Otherwise I can take care of it. Edit: never mind, the loading error does not come from the weight names; it occurs when vision is disabled (image limit = 0) but the model still has vision weights. I'll push a fix.
@juliendenize Hey, good catch with that. Let me know if you'd like to carry this ticket forward.
Purpose
This PR fixes a bug introduced by PR #32780 where loading Mistral/vision-enabled checkpoints (e.g., `mistralai/Devstral-Small-2-24B-Instruct-2512`) fails with a `KeyError: 'merging_layer.weight'`. The refactor in PR #32780 moved Mistral-specific code from `llama.py` to a new `mistral.py` file, which is correct. However, the `pixtral.py` file had unsafe dictionary accesses that caused KeyErrors when loading checkpoints containing `multi_modal_projector.patch_merger` weights.
Root Cause
In `pixtral.py`'s `load_weights` method, the code used direct dictionary access without null checks. This caused failures when:
- checkpoints contained `multi_modal_projector.patch_merger.merging_layer.weight` keys
- the `patch_merger` module didn't have the expected parameter (or it was `None`)
- the `is_patch_merger` function didn't recognize the `multi_modal_projector.patch_merger` prefix
Changes
1. Fixed unsafe dictionary access for patch_merger
Changed from direct access to safe access with a null check:
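The original diff snippet was not preserved on this page; the following is a minimal sketch of the before/after pattern. The names (`patch_merger_dict`, the placeholder value) are illustrative, not the exact `pixtral.py` code:

```python
# Illustrative sketch of the fix pattern; names are placeholders,
# not the exact pixtral.py implementation.
patch_merger_dict = {
    # A checkpoint may or may not contain this key.
    "merging_layer.weight": "tensor-placeholder",
}

# Before: direct indexing raises KeyError when the key is absent.
#   weight = patch_merger_dict["merging_layer.weight"]

# After: .get() returns None for a missing key, and the None check
# lets the loader skip weights the checkpoint does not provide.
weight = patch_merger_dict.get("merging_layer.weight")
if weight is not None:
    pass  # load `weight` into the module parameter here

missing = patch_merger_dict.get("nonexistent.weight")  # None, no KeyError
```

Per the PR description, the same `.get()` plus `None`-check pattern is applied to the `pre_mm_projector_norm_dict` access as well.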
2. Fixed unsafe dictionary access for pre_mm_projector_norm
Applied the same fix to
the `pre_mm_projector_norm_dict` access.
3. Improved patch_merger weight detection
Updated
the `is_patch_merger` function to recognize both prefixes (including `multi_modal_projector.patch_merger`).
Test Plan
Reproduce the original issue:
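The exact reproduction command was elided from this page; a plausible invocation, assuming the standard vLLM Mistral-format flags, would be:

```shell
# Hypothetical reproduction command (the PR's original command was not
# preserved); any checkpoint carrying multi_modal_projector.patch_merger
# weights should trigger the KeyError before the fix.
vllm serve mistralai/Devstral-Small-2-24B-Instruct-2512 \
  --tokenizer-mode mistral \
  --config-format mistral \
  --load-format mistral
```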
This should fail with
`KeyError: 'merging_layer.weight'` before the fix.
Verify the fix:
(verified with the `mistralai/Devstral-Small-2-24B-Instruct-2512` model)
Test with other Mistral models:
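For the regression check, any text-only Mistral checkpoint should still load after the change; for example (the model choice here is illustrative, not taken from the PR):

```shell
# Illustrative regression check: a text-only Mistral model should
# still load cleanly with the patched pixtral.py in place.
vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral
```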
Test Result
✅ Before fix: Command fails with
`KeyError: 'merging_layer.weight'`
✅ After fix: Command completes successfully; the model loads and is ready for inference
✅ Regression test: Regular Mistral models still load correctly
✅ Multimodality: Works; images can be passed as input to the model (tested with `mistralai/Devstral-Small-2-24B-Instruct-2512`)
Related Issues
`merging_layer.weight` when loading Mistral/vision-enabled checkpoints after PR #32780 refactor #32959
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.