-
Notifications
You must be signed in to change notification settings - Fork 32.9k
Dynamic weight conversion is recursive #44300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2c1d904
0caafab
edb4d27
cbbecf4
1fd89ae
0a4a829
af57e2e
ea3ac63
fbda051
88f18be
6eb0e13
b9f29ab
d452817
500e96b
3fcf527
23a3e13
42e85f0
e85b103
ce82bcc
d5ab4fd
dcc95d5
f4b5888
215da83
370feb6
3c6a23f
b3559f7
960c716
52560f6
17294cc
7503c12
aac4bba
1c770a1
6780970
a1220c9
5a68bf7
d10cb69
abb8001
503c206
f92c063
ca68663
4c65203
cc19ab9
7e3d40e
29e600a
df86ff7
ebcb04f
99f0a05
58e0d1c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,7 +55,6 @@ | |
| "qwen3_omni_moe": "qwen2_moe", | ||
| "qwen3_omni_moe_thinker": "qwen2_moe", | ||
| "qwen3_next": "qwen2_moe", | ||
| "qwen3_5_moe": "qwen2_moe", | ||
| "hunyuan_v1_moe": "qwen2_moe", | ||
| "flex_olmo": "qwen2_moe", | ||
| "olmoe": "qwen2_moe", | ||
|
|
@@ -91,7 +90,6 @@ def _build_checkpoint_conversion_mapping(): | |
| ], | ||
| "colpali": [ | ||
| WeightRenaming(source_patterns=r"vlm(?!\.model)", target_patterns="vlm.model"), | ||
| WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"), | ||
| ], | ||
| "emu3": [ | ||
| WeightRenaming(source_patterns=r"text_model.model", target_patterns="text_model"), | ||
|
|
@@ -109,20 +107,16 @@ def _build_checkpoint_conversion_mapping(): | |
| source_patterns=r"(?<!_)model(?!\.(language_model|visual))", target_patterns="model.language_model" | ||
| ), | ||
| ], | ||
| "colqwen2": [ | ||
| WeightRenaming( | ||
| source_patterns=r"vlm.model(?!\.(language_model|visual))", | ||
| target_patterns="vlm.model.language_model", | ||
| ), | ||
| ], | ||
| "gemma3n_text": [ | ||
| WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), | ||
| ], | ||
| "timm_wrapper": [ | ||
| # Simply add the prefix `timm_model`. Similar to `base_model_prefix` but also removes prefix | ||
| # when saving.TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming | ||
| # when saving. TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming | ||
| # Note: we don't add `timm_model` when it is part of a bigger VLM, because they already have `timm_model` | ||
| # saved in state dict keys. Thus the look behind check. Should be fixed by proper `add_prefix`! | ||
| WeightRenaming( | ||
| source_patterns=r"(.+)", | ||
| source_patterns=r"^(?!(?:model\.|backbone\.|tower\.))(.+)$", | ||
| target_patterns=r"timm_model.\1", | ||
| ) | ||
| ], | ||
|
|
@@ -147,7 +141,6 @@ def _build_checkpoint_conversion_mapping(): | |
| target_patterns="model.vlm.language_model.embed_tokens", | ||
| ), | ||
| ], | ||
| "chmv2": [WeightRenaming(r"backbone.layer.", r"backbone.model.layer.")], | ||
| "dinov3_convnext": [WeightRenaming(r"(?<!model\.)stages", r"model.stages")], | ||
| "dinov3_vit": [WeightRenaming(r"(?<!model\.)layer.", r"model.layer.")], | ||
| "timesfm2_5": [ | ||
|
|
@@ -161,21 +154,16 @@ def _build_checkpoint_conversion_mapping(): | |
| "qwen3_5_text": [ | ||
| WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"), | ||
| ], | ||
| "t5gemma2": [ | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.embed_tokens.", "encoder.text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.norm.", "encoder.text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model\.)encoder.layers.", "encoder.text_model.layers."), | ||
| ], | ||
| "sam3_tracker": [ | ||
| WeightRenaming( | ||
| source_patterns=r"detector_model.vision_encoder.backbone.", target_patterns="vision_encoder.backbone." | ||
| ), | ||
| WeightRenaming(source_patterns=r"tracker_neck.", target_patterns="vision_encoder.neck."), | ||
| ], | ||
| "t5gemma2_encoder": [ | ||
| WeightRenaming("^embed_tokens.", "text_model.embed_tokens."), | ||
| WeightRenaming("^norm.", "text_model.norm."), | ||
| WeightRenaming("^layers.", "text_model.layers."), | ||
| WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)embed_tokens\.", "text_model.embed_tokens."), | ||
| WeightRenaming(r"(?<!decoder\.)(?<!text_model\.)(?<!layer)(?<!_)norm\.", "text_model.norm."), | ||
| WeightRenaming(r"(?<!vision_model.encoder\.)(?<!decoder\.)(?<!text_model\.)layers.", "text_model.layers."), | ||
| ], | ||
| "mixtral": [ | ||
| WeightRenaming(".block_sparse_moe.", ".mlp."), | ||
|
|
@@ -320,6 +308,24 @@ def _build_checkpoint_conversion_mapping(): | |
| WeightRenaming("out_proj", "o_proj"), | ||
| WeightRenaming(r"layers.(\d+).fc1", r"layers.\1.mlp.fc1"), | ||
| WeightRenaming(r"layers.(\d+).fc2", r"layers.\1.mlp.fc2"), | ||
| # `DetrForSegmentation` | ||
| WeightRenaming("bbox_attention.q_linear", "bbox_attention.q_proj"), | ||
| WeightRenaming("bbox_attention.k_linear", "bbox_attention.k_proj"), | ||
| # Mask head refactor | ||
|
Comment on lines
+311
to
+314
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. this was straightforward, so I just moved it here and deleted |
||
| WeightRenaming("mask_head.lay1", "mask_head.conv1.conv"), | ||
| WeightRenaming("mask_head.gn1", "mask_head.conv1.norm"), | ||
| WeightRenaming("mask_head.lay2", "mask_head.conv2.conv"), | ||
| WeightRenaming("mask_head.gn2", "mask_head.conv2.norm"), | ||
| WeightRenaming("mask_head.adapter1", "mask_head.fpn_stages.0.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay3", "mask_head.fpn_stages.0.refine.conv"), | ||
| WeightRenaming("mask_head.gn3", "mask_head.fpn_stages.0.refine.norm"), | ||
| WeightRenaming("mask_head.adapter2", "mask_head.fpn_stages.1.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay4", "mask_head.fpn_stages.1.refine.conv"), | ||
| WeightRenaming("mask_head.gn4", "mask_head.fpn_stages.1.refine.norm"), | ||
| WeightRenaming("mask_head.adapter3", "mask_head.fpn_stages.2.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay5", "mask_head.fpn_stages.2.refine.conv"), | ||
| WeightRenaming("mask_head.gn5", "mask_head.fpn_stages.2.refine.norm"), | ||
| WeightRenaming("mask_head.out_lay", "mask_head.output_conv"), | ||
| ], | ||
| "rt_detr": [ | ||
| WeightRenaming("out_proj", "o_proj"), | ||
|
|
@@ -348,6 +354,24 @@ def _build_checkpoint_conversion_mapping(): | |
| WeightRenaming( | ||
| r"decoder.layers.(\d+).ca_qpos_sine_proj", r"decoder.layers.\1.encoder_attn.q_pos_sine_proj" | ||
| ), | ||
| # The rest of patterns are used only in `ConditionalDetrForSegmentation` | ||
| WeightRenaming("bbox_attention.q_linear", "bbox_attention.q_proj"), | ||
| WeightRenaming("bbox_attention.k_linear", "bbox_attention.k_proj"), | ||
| # Mask head refactor | ||
| WeightRenaming("mask_head.lay1", "mask_head.conv1.conv"), | ||
| WeightRenaming("mask_head.gn1", "mask_head.conv1.norm"), | ||
| WeightRenaming("mask_head.lay2", "mask_head.conv2.conv"), | ||
| WeightRenaming("mask_head.gn2", "mask_head.conv2.norm"), | ||
| WeightRenaming("mask_head.adapter1", "mask_head.fpn_stages.0.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay3", "mask_head.fpn_stages.0.refine.conv"), | ||
| WeightRenaming("mask_head.gn3", "mask_head.fpn_stages.0.refine.norm"), | ||
| WeightRenaming("mask_head.adapter2", "mask_head.fpn_stages.1.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay4", "mask_head.fpn_stages.1.refine.conv"), | ||
| WeightRenaming("mask_head.gn4", "mask_head.fpn_stages.1.refine.norm"), | ||
| WeightRenaming("mask_head.adapter3", "mask_head.fpn_stages.2.fpn_adapter"), | ||
| WeightRenaming("mask_head.lay5", "mask_head.fpn_stages.2.refine.conv"), | ||
| WeightRenaming("mask_head.gn5", "mask_head.fpn_stages.2.refine.norm"), | ||
| WeightRenaming("mask_head.out_lay", "mask_head.output_conv"), | ||
| ], | ||
| "deformable_detr": [ | ||
| WeightRenaming("backbone.conv_encoder", "backbone"), | ||
|
|
@@ -503,8 +527,12 @@ def register_checkpoint_conversion_mapping( | |
| _checkpoint_conversion_mapping_cache[model_type] = mapping | ||
|
|
||
|
|
||
| # DO NOT MODIFY, KEPT FOR BC ONLY | ||
| VLMS = ["detr"] | ||
| def extract_weight_conversions_for_model(model: PreTrainedModel) -> list[WeightConverter | WeightRenaming] | None: | ||
| model_type = getattr(model.config, "model_type", None) | ||
| if model_type is not None: | ||
| model_specific_conversions = get_checkpoint_conversion_mapping(model_type) | ||
| return model_specific_conversions | ||
| return None | ||
|
|
||
|
|
||
| def get_model_conversion_mapping( | ||
|
|
@@ -517,28 +545,35 @@ def get_model_conversion_mapping( | |
| For a given `model`, obtain the weight conversion mapping if any are registered either as a simple renaming | ||
| `_checkpoint_conversion_mapping` class argument, or in the general WeightConverter mapping. | ||
| """ | ||
| # Lazy import to avoid circular import issues | ||
| from .modeling_utils import PreTrainedModel | ||
|
|
||
| # note: this function is used in PEFT, so changing the API requires coordination | ||
| weight_conversions = [] | ||
|
|
||
| # Load models with explicit, user-provided key mapping | ||
| if key_mapping is not None: | ||
| weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()] | ||
| elif any( | ||
| allowed_name in class_name.__name__.lower() | ||
| for class_name in model.__class__.__mro__[:-1] | ||
| for allowed_name in VLMS | ||
| ): | ||
| weight_conversions = [ | ||
| WeightRenaming(source_patterns=k, target_patterns=v) | ||
| for k, v in model._checkpoint_conversion_mapping.items() | ||
| ] | ||
|
|
||
| # TODO: should be checked recursively on submodels!! | ||
| model_type = getattr(model.config, "model_type", None) | ||
| if model_type is not None: | ||
| model_specific_conversions = get_checkpoint_conversion_mapping(model_type) | ||
| if model_specific_conversions is not None: | ||
| weight_conversions.extend(model_specific_conversions) | ||
| # Model have several `PreTrainedModel` within with the same model type | ||
| # For ex: XForConditionalGeneration -> XModel. We don't want to apply the same | ||
| # conversion pattern twice because of that | ||
|
Comment on lines
+558
to
+560
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. should we consider linking each regex with a certain class in the future? It will also help with those dangling |
||
| seen_model_types = set() | ||
| if (conversions := extract_weight_conversions_for_model(model)) is not None: | ||
| weight_conversions.extend(conversions) | ||
| seen_model_types.add(model.config.model_type) | ||
|
|
||
| # Recurse over submodules and collect all conversions | ||
| for submodule in model.modules(): | ||
| if ( | ||
| submodule is not model | ||
| and isinstance(submodule, PreTrainedModel) | ||
| and submodule.config.model_type not in seen_model_types | ||
| ): | ||
| conversions = extract_weight_conversions_for_model(submodule) | ||
| if conversions is not None: | ||
| weight_conversions.extend(conversions) | ||
| seen_model_types.add(submodule.config.model_type) | ||
|
zucchini-nlp marked this conversation as resolved.
|
||
|
|
||
| if add_legacy: | ||
| weight_conversions.extend(get_checkpoint_conversion_mapping("legacy")) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1435,26 +1435,6 @@ def forward( | |
| """ | ||
| ) | ||
| class DetrForSegmentation(DetrPreTrainedModel): | ||
| _checkpoint_conversion_mapping = { | ||
| "bbox_attention.q_linear": "bbox_attention.q_proj", | ||
| "bbox_attention.k_linear": "bbox_attention.k_proj", | ||
|
Comment on lines
-1438
to
-1440
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. moved to |
||
| # Mask head refactor | ||
| "mask_head.lay1": "mask_head.conv1.conv", | ||
| "mask_head.gn1": "mask_head.conv1.norm", | ||
| "mask_head.lay2": "mask_head.conv2.conv", | ||
| "mask_head.gn2": "mask_head.conv2.norm", | ||
| "mask_head.adapter1": "mask_head.fpn_stages.0.fpn_adapter", | ||
| "mask_head.lay3": "mask_head.fpn_stages.0.refine.conv", | ||
| "mask_head.gn3": "mask_head.fpn_stages.0.refine.norm", | ||
| "mask_head.adapter2": "mask_head.fpn_stages.1.fpn_adapter", | ||
| "mask_head.lay4": "mask_head.fpn_stages.1.refine.conv", | ||
| "mask_head.gn4": "mask_head.fpn_stages.1.refine.norm", | ||
| "mask_head.adapter3": "mask_head.fpn_stages.2.fpn_adapter", | ||
| "mask_head.lay5": "mask_head.fpn_stages.2.refine.conv", | ||
| "mask_head.gn5": "mask_head.fpn_stages.2.refine.norm", | ||
| "mask_head.out_lay": "mask_head.output_conv", | ||
| } | ||
|
|
||
| def __init__(self, config: DetrConfig): | ||
| super().__init__(config) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.