-
Notifications
You must be signed in to change notification settings - Fork 33.6k
Dynamic weight conversion is recursive #44300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
2c1d904
0caafab
edb4d27
cbbecf4
1fd89ae
0a4a829
af57e2e
ea3ac63
fbda051
88f18be
6eb0e13
b9f29ab
d452817
500e96b
3fcf527
23a3e13
42e85f0
e85b103
ce82bcc
d5ab4fd
dcc95d5
f4b5888
215da83
370feb6
3c6a23f
b3559f7
960c716
52560f6
17294cc
7503c12
aac4bba
1c770a1
6780970
a1220c9
5a68bf7
d10cb69
abb8001
503c206
f92c063
ca68663
4c65203
cc19ab9
7e3d40e
29e600a
df86ff7
ebcb04f
99f0a05
58e0d1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,19 +55,64 @@ | |
| "qwen3_omni_moe": "qwen2_moe", | ||
| "qwen3_omni_moe_thinker": "qwen2_moe", | ||
| "qwen3_next": "qwen2_moe", | ||
| "qwen3_5_moe": "qwen2_moe", | ||
| "hunyuan_v1_moe": "qwen2_moe", | ||
| "flex_olmo": "qwen2_moe", | ||
| "olmoe": "qwen2_moe", | ||
| "exaone_moe": "qwen2_moe", | ||
| "rt_detr_v2": "rt_detr", | ||
| "pp_doclayout_v2": "rt_detr", | ||
| "pp_doclayout_v3": "rt_detr", | ||
| "paligemma": "llava", | ||
| "ayavision": "llava", | ||
| "fuyu": "llava", | ||
| "gotocr2": "llava", | ||
| "gemma3": "llava", | ||
| "internvl": "llava", | ||
| "llava_next": "llava", | ||
| "llava_next_video": "llava", | ||
| "llava_onevision": "llava", | ||
| "vipllava": "llava", | ||
| "video_llava": "llava", | ||
| "mistral3": "llava", | ||
| "mllama": "llava", | ||
| "qwen2_5_vl": "qwen2_vl", | ||
| "sam3_tracker_video": "sam3_tracker", | ||
| } | ||
|
|
||
|
|
||
| def _build_checkpoint_conversion_mapping(): | ||
| mapping = { | ||
| "llava": [ | ||
| WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), | ||
| WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"), | ||
| ], | ||
| "colpali": [ | ||
| WeightRenaming(source_patterns=r"vlm(?!\.model)", target_patterns="vlm.model"), | ||
| ], | ||
| "emu3": [ | ||
| WeightRenaming(source_patterns=r"text_model.model", target_patterns="text_model"), | ||
| WeightRenaming(source_patterns=r"text_model.lm_head", target_patterns="lm_head"), | ||
| ], | ||
| "paddleocrvl": [ | ||
| WeightRenaming( | ||
| source_patterns=r"^model(?!(\.visual|\.projector|\.language_model))", | ||
| target_patterns="model.language_model", | ||
| ), | ||
| ], | ||
| "qwen2_vl": [ | ||
| WeightRenaming( | ||
| source_patterns=r"(?<!_)model(?!\.(language_model|visual))", target_patterns="model.language_model" | ||
| ), | ||
| ], | ||
| "colqwen2": [ | ||
| WeightRenaming( | ||
| source_patterns=r"vlm.model(?!\.(language_model|visual))", | ||
| target_patterns="vlm.model.language_model", | ||
| ), | ||
| ], | ||
| "gemma3n_text": [ | ||
| WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), | ||
| ], | ||
| "timesfm2_5": [ | ||
| WeightRenaming("ff0", "fc1"), | ||
| WeightRenaming("ff1", "fc2"), | ||
|
|
@@ -79,15 +124,16 @@ def _build_checkpoint_conversion_mapping(): | |
| "qwen3_5_text": [ | ||
| WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), | ||
| ], | ||
| "t5gemma2": [ | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.embed_tokens.", "encoder.text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.norm.", "encoder.text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.layers.", "encoder.text_model.layers."), | ||
| "sam3_tracker": [ | ||
| WeightRenaming( | ||
| source_patterns=r"detector_model.vision_encoder.backbone.", target_patterns="vision_encoder.backbone" | ||
| ), | ||
| WeightRenaming(source_patterns=r"tracker_neck.", target_patterns="vision_encoder.neck."), | ||
| ], | ||
| "t5gemma2_encoder": [ | ||
| WeightRenaming("^embed_tokens.", "text_model.embed_tokens."), | ||
| WeightRenaming("^norm.", "text_model.norm."), | ||
| WeightRenaming("^layers.", "text_model.layers."), | ||
| WeightRenaming(r"(?<!decoder)(?<!text_model)\.embed_tokens\.", "text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!decoder)(?<!text_model)\.norm\.", "text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model.encoder)(?<!decoder)(?<!text_model)\.layers.", "text_model.layers."), | ||
|
zucchini-nlp marked this conversation as resolved.
Outdated
|
||
| ], | ||
| "gpt_oss": [ | ||
| # NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint. | ||
|
|
@@ -205,6 +251,7 @@ def _build_checkpoint_conversion_mapping(): | |
| # language model | ||
| WeightRenaming(r"(?<!language_model\.)embed_tokens", "language_model.embed_tokens"), | ||
| WeightRenaming(r"(?<!language_model\.)layers", "language_model.layers"), | ||
| WeightRenaming(r"(?<!_)(?<!\w)norm\.", "language_model.norm."), | ||
| WeightConverter( | ||
| source_patterns="mlp.gate.weight_1", | ||
| target_patterns="mlp.vision_moe.gate.weight", | ||
|
|
@@ -322,14 +369,6 @@ def _build_checkpoint_conversion_mapping(): | |
| operations=[MergeModulelist(dim=0)], | ||
| ), | ||
| ], | ||
| "timm_wrapper": [ | ||
| # Simply add the prefix `timm_model` | ||
| # TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming | ||
| WeightRenaming( | ||
| source_patterns=r"(.+)", | ||
| target_patterns=r"timm_model.\1", | ||
| ) | ||
|
zucchini-nlp marked this conversation as resolved.
Outdated
|
||
| ], | ||
|
zucchini-nlp marked this conversation as resolved.
Outdated
|
||
| "legacy": [ | ||
| WeightRenaming( | ||
| source_patterns="LayerNorm.gamma", | ||
|
|
@@ -424,34 +463,7 @@ def register_checkpoint_conversion_mapping( | |
|
|
||
|
|
||
| # DO NOT MODIFY, KEPT FOR BC ONLY | ||
| VLMS = [ | ||
| "aria", | ||
| "ayavision", | ||
| "colpali", | ||
| "emu3", | ||
| "fuyu", | ||
| "gotocr2", | ||
| "gemma3", | ||
| "internvl", | ||
| "llava", # all llava prefixed models fall under this check | ||
| "mistral3", | ||
| "mllama", | ||
| "paligemma", | ||
| "shieldgemma2", | ||
| "qwen2vl", | ||
| "qwen2_5_vl", | ||
| "videollava", | ||
| "vipllava", | ||
| "sam3_video", | ||
| "sam3", | ||
| "sam3_tracker", | ||
| "sam3_tracker_video", | ||
| "paddleocrvl", | ||
| # NOTE: Slightly different from `model_type` (to follow naming conventions in vllm/sglang) | ||
| "ernie4_5_vlmoe", | ||
| "ernie4_5_vl_moe", # BC alias | ||
| "detr", | ||
| ] | ||
| VLMS = ["detr"] | ||
|
Collaborator
Hmm, what's the issue with detr? :D
Member
Author
My laziness, haha: the detr models have a huge block of conversions, but they are used only for specific task classes.
Member
Author
I'll try to move it, as long as the tests don't start complaining. |
||
|
|
||
|
|
||
| def get_model_conversion_mapping( | ||
|
|
@@ -479,7 +491,6 @@ def get_model_conversion_mapping( | |
| for k, v in model._checkpoint_conversion_mapping.items() | ||
| ] | ||
|
|
||
| # TODO: should be checked recursively on submodels!! | ||
| model_type = getattr(model.config, "model_type", None) | ||
| if model_type is not None: | ||
| model_specific_conversions = get_checkpoint_conversion_mapping(model_type) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.