Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
2c1d904
split out from timm PR
zucchini-nlp Feb 26, 2026
0caafab
all other VLMs
zucchini-nlp Feb 26, 2026
edb4d27
timm backbone is not here
zucchini-nlp Feb 26, 2026
cbbecf4
oops, extra key is breaking eveerything
zucchini-nlp Feb 26, 2026
1fd89ae
.
zucchini-nlp Feb 26, 2026
0a4a829
this test
zucchini-nlp Feb 26, 2026
af57e2e
maybe
zucchini-nlp Feb 26, 2026
ea3ac63
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Feb 26, 2026
fbda051
fix missing keys when loading from hub
zucchini-nlp Feb 27, 2026
88f18be
now fix fast tests
zucchini-nlp Feb 27, 2026
6eb0e13
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Feb 27, 2026
b9f29ab
merge gone wrong
zucchini-nlp Feb 27, 2026
d452817
fix repo
zucchini-nlp Feb 27, 2026
500e96b
refine the regex again!
zucchini-nlp Feb 27, 2026
3fcf527
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 3, 2026
23a3e13
close the bracket
zucchini-nlp Mar 3, 2026
42e85f0
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 10, 2026
e85b103
Apply suggestions from code review
zucchini-nlp Mar 11, 2026
ce82bcc
merge main
zucchini-nlp Mar 12, 2026
d5ab4fd
main
zucchini-nlp Mar 17, 2026
dcc95d5
revert unrelated
zucchini-nlp Mar 17, 2026
f4b5888
!
zucchini-nlp Mar 17, 2026
215da83
revert more
zucchini-nlp Mar 17, 2026
370feb6
add submodule prefix when recursing
zucchini-nlp Mar 17, 2026
3c6a23f
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 17, 2026
b3559f7
i'll need to fix maskformer later
zucchini-nlp Mar 17, 2026
960c716
dont duplicate the same pattern twice
zucchini-nlp Mar 18, 2026
52560f6
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 18, 2026
17294cc
fix modular
zucchini-nlp Mar 18, 2026
7503c12
detr
zucchini-nlp Mar 18, 2026
aac4bba
colpali isn't working still!
zucchini-nlp Mar 19, 2026
1c770a1
oke, so this can be fine for now
zucchini-nlp Mar 19, 2026
6780970
!
zucchini-nlp Mar 19, 2026
a1220c9
revert
zucchini-nlp Mar 19, 2026
5a68bf7
dot lost in regex and comments
zucchini-nlp Mar 20, 2026
d10cb69
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 20, 2026
abb8001
timm wrapper is weird
zucchini-nlp Mar 20, 2026
503c206
skip these, timm wrapper
zucchini-nlp Mar 23, 2026
f92c063
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 23, 2026
ca68663
bye bye timm
zucchini-nlp Mar 23, 2026
4c65203
make repo check happy
zucchini-nlp Mar 23, 2026
cc19ab9
Revert "bye bye timm"
zucchini-nlp Mar 24, 2026
7e3d40e
love timm!
zucchini-nlp Mar 24, 2026
29e600a
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 25, 2026
df86ff7
Apply repo consistency fixes
github-actions[bot] Mar 25, 2026
ebcb04f
oke, the bot can't fix it so here we go
zucchini-nlp Mar 25, 2026
99f0a05
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 25, 2026
58e0d1c
Merge branch 'main' into convert-weights-recursive
zucchini-nlp Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 56 additions & 45 deletions src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,64 @@
"qwen3_omni_moe": "qwen2_moe",
"qwen3_omni_moe_thinker": "qwen2_moe",
"qwen3_next": "qwen2_moe",
"qwen3_5_moe": "qwen2_moe",
Comment thread
zucchini-nlp marked this conversation as resolved.
"hunyuan_v1_moe": "qwen2_moe",
"flex_olmo": "qwen2_moe",
"olmoe": "qwen2_moe",
"exaone_moe": "qwen2_moe",
"rt_detr_v2": "rt_detr",
"pp_doclayout_v2": "rt_detr",
"pp_doclayout_v3": "rt_detr",
"paligemma": "llava",
"ayavision": "llava",
"fuyu": "llava",
"gotocr2": "llava",
"gemma3": "llava",
"internvl": "llava",
"llava_next": "llava",
"llava_next_video": "llava",
"llava_onevision": "llava",
"vipllava": "llava",
"video_llava": "llava",
"mistral3": "llava",
"mllama": "llava",
"qwen2_5_vl": "qwen2_vl",
"sam3_tracker_video": "sam3_tracker",
}


def _build_checkpoint_conversion_mapping():
mapping = {
"llava": [
WeightRenaming(source_patterns=r"language_model.model", target_patterns="language_model"),
WeightRenaming(source_patterns=r"language_model.lm_head", target_patterns="lm_head"),
],
"colpali": [
WeightRenaming(source_patterns=r"vlm(?!\.model)", target_patterns="vlm.model"),
],
"emu3": [
WeightRenaming(source_patterns=r"text_model.model", target_patterns="text_model"),
WeightRenaming(source_patterns=r"text_model.lm_head", target_patterns="lm_head"),
],
"paddleocrvl": [
WeightRenaming(
source_patterns=r"^model(?!(\.visual|\.projector|\.language_model))",
target_patterns="model.language_model",
),
],
"qwen2_vl": [
WeightRenaming(
source_patterns=r"(?<!_)model(?!\.(language_model|visual))", target_patterns="model.language_model"
),
],
"colqwen2": [
WeightRenaming(
source_patterns=r"vlm.model(?!\.(language_model|visual))",
target_patterns="vlm.model.language_model",
),
],
"gemma3n_text": [
WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
],
"timesfm2_5": [
WeightRenaming("ff0", "fc1"),
WeightRenaming("ff1", "fc2"),
Expand All @@ -79,15 +124,16 @@ def _build_checkpoint_conversion_mapping():
"qwen3_5_text": [
WeightRenaming(source_patterns=r"^model.language_model", target_patterns="model"),
],
"t5gemma2": [
WeightRenaming(r"(?<!vision_model\.)encoder.embed_tokens.", "encoder.text_model.embed_tokens."),
WeightRenaming(r"(?<!vision_model\.)encoder.norm.", "encoder.text_model.norm."),
WeightRenaming(r"(?<!vision_model\.)encoder.layers.", "encoder.text_model.layers."),
"sam3_tracker": [
WeightRenaming(
source_patterns=r"detector_model.vision_encoder.backbone.", target_patterns="vision_encoder.backbone"
),
WeightRenaming(source_patterns=r"tracker_neck.", target_patterns="vision_encoder.neck."),
],
"t5gemma2_encoder": [
WeightRenaming("^embed_tokens.", "text_model.embed_tokens."),
WeightRenaming("^norm.", "text_model.norm."),
WeightRenaming("^layers.", "text_model.layers."),
WeightRenaming(r"(?<!decoder)(?<!text_model)\.embed_tokens\.", "text_model.embed_tokens."),
WeightRenaming(r"(?<!decoder)(?<!text_model)\.norm\.", "text_model.norm."),
WeightRenaming(r"(?<!vision_model.encoder)(?<!decoder)(?<!text_model)\.layers.", "text_model.layers."),
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
],
"gpt_oss": [
# NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint.
Expand Down Expand Up @@ -205,6 +251,7 @@ def _build_checkpoint_conversion_mapping():
# language model
WeightRenaming(r"(?<!language_model\.)embed_tokens", "language_model.embed_tokens"),
WeightRenaming(r"(?<!language_model\.)layers", "language_model.layers"),
WeightRenaming(r"(?<!_)(?<!\w)norm\.", "language_model.norm."),
WeightConverter(
source_patterns="mlp.gate.weight_1",
target_patterns="mlp.vision_moe.gate.weight",
Expand Down Expand Up @@ -322,14 +369,6 @@ def _build_checkpoint_conversion_mapping():
operations=[MergeModulelist(dim=0)],
),
],
"timm_wrapper": [
# Simply add the prefix `timm_model`
# TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming
WeightRenaming(
source_patterns=r"(.+)",
target_patterns=r"timm_model.\1",
)
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
],
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
"legacy": [
WeightRenaming(
source_patterns="LayerNorm.gamma",
Expand Down Expand Up @@ -424,34 +463,7 @@ def register_checkpoint_conversion_mapping(


# DO NOT MODIFY, KEPT FOR BC ONLY
VLMS = [
"aria",
"ayavision",
"colpali",
"emu3",
"fuyu",
"gotocr2",
"gemma3",
"internvl",
"llava", # all llava prefixed models fall under this check
"mistral3",
"mllama",
"paligemma",
"shieldgemma2",
"qwen2vl",
"qwen2_5_vl",
"videollava",
"vipllava",
"sam3_video",
"sam3",
"sam3_tracker",
"sam3_tracker_video",
"paddleocrvl",
# NOTE: Slightly different from `model_type` (to follow naming conventions in vllm/sglang)
"ernie4_5_vlmoe",
"ernie4_5_vl_moe", # BC alias
"detr",
]
VLMS = ["detr"]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, what's the issue with detr :D

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my laziness haha, detr models have a huge block of conversions but used only for specific task classes

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'll try to move it, if tests don't start complaining



def get_model_conversion_mapping(
Expand Down Expand Up @@ -479,7 +491,6 @@ def get_model_conversion_mapping(
for k, v in model._checkpoint_conversion_mapping.items()
]

# TODO: should be checked recursively on submodels!!
model_type = getattr(model.config, "model_type", None)
if model_type is not None:
model_specific_conversions = get_checkpoint_conversion_mapping(model_type)
Expand Down
41 changes: 28 additions & 13 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def reverse_op(self) -> ConversionOps:
return Force16BytesAlignment()


def process_target_pattern(pattern: str) -> tuple[str, str | None]:
def process_pattern_for_reverse_mapping(target_pattern: str, source_pattern: str = "") -> tuple[str, str | None]:
"""
Process a target pattern for reverse mapping (when targets become sources).

Expand All @@ -498,26 +498,35 @@ def process_target_pattern(pattern: str) -> tuple[str, str | None]:
- Detects capturing groups and replaces them with `\\1` backreference

Args:
pattern: The target pattern to process for reverse mapping.
target_pattern: The target pattern to process for reverse mapping.
source_pattern: The source pattern to process for reverse mapping.

Returns:
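A tuple of (processed_pattern, captured_group, source_pattern) where captured_group is
the original capturing group found (e.g., "(encoder|decoder)") or None, and source_pattern
is the given source pattern with any `^`/`$` anchor found on the target pattern added to it.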
"""
# Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping
pattern = pattern.removeprefix("^")
if target_pattern.startswith("^"):
source_pattern = f"^{source_pattern}"
target_pattern = target_pattern.removeprefix("^")
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated

# Some mapping contains `$` to notify end of string when matching -> remove it during reverse mapping
pattern = pattern.removesuffix("$")
if target_pattern.endswith("$"):
source_pattern = f"{source_pattern}$"
target_pattern = target_pattern.removesuffix("$")

Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
# Remove negative lookahead/behind if any. This is ugly but needed for reverse mapping of
# Qwen2.5, Sam3, Ernie4.5 VL MoE!
pattern = re.sub(r"\(\?.+\)", "", pattern)
target_pattern = re.sub(r"\(\?.+?\)?\)", "", target_pattern)
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
# Remove the backslash for literal dots
target_pattern = target_pattern.replace(r"\.", ".")
# Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3)
capturing_group_match = re.search(r"\(.+?\)", pattern)
capturing_group_match = re.search(r"\(.+?\)", target_pattern)
captured_group = None
if capturing_group_match:
captured_group = capturing_group_match.group(0)
pattern = pattern.replace(captured_group, r"\1", 1)
return pattern, captured_group
target_pattern = target_pattern.replace(captured_group, r"\1", 1)
return target_pattern, captured_group, source_pattern


@dataclass(slots=True)
Expand Down Expand Up @@ -551,8 +560,16 @@ def __post_init__(self):
# Process target_patterns: detect capturing groups and replace with \1
# Store the original capturing group patterns for reverse mapping
target_capturing_groups: list[str] = []
for i, pattern in enumerate(self.target_patterns):
self.target_patterns[i], captured_group = process_target_pattern(pattern)

is_weight_renaming = self.__class__.__name__ == "WeightRenaming"
for i, target_pattern in enumerate(self.target_patterns):
if is_weight_renaming:
source_pattern = self.source_patterns[i]
self.target_patterns[i], captured_group, self.source_patterns[i] = process_pattern_for_reverse_mapping(
target_pattern, source_pattern
)
else:
self.target_patterns[i], captured_group, _ = process_pattern_for_reverse_mapping(target_pattern)
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
if captured_group is not None:
target_capturing_groups.append(captured_group)

Expand Down Expand Up @@ -1267,10 +1284,8 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch
# In this case, the model was not created with `from_pretrained` -> let's check if it's in the hardcoded
# mappings, and recreate the mapping from there if it is
if weight_conversions is None:
from .conversion_mapping import get_model_conversion_mapping

# Do not resave with the legacy renaming, if present
weight_conversions = get_model_conversion_mapping(model, add_legacy=False)
weight_conversions = model.get_weight_conversions_recursively(add_legacy=False)
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
weight_conversions = weight_conversions if len(weight_conversions) > 0 else None

# We did not find any operations to perform -> quick escape
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ..conversion_mapping import (
_MODEL_TO_CONVERSION_PATTERN,
get_checkpoint_conversion_mapping,
get_model_conversion_mapping,
)
from ..core_model_loading import (
Concatenate,
Expand Down Expand Up @@ -519,7 +518,7 @@ def load_adapter(
**load_config.download_kwargs,
)

weight_conversions = get_model_conversion_mapping(self)
weight_conversions = self.get_weight_conversions_recursively()
peft_config = convert_peft_config_for_transformers(peft_config, model=self, conversions=weight_conversions)

if hasattr(peft_config, "inference_mode"):
Expand Down
18 changes: 16 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4110,8 +4110,8 @@ def from_pretrained(
# instantiated model, as the flags can be modified by instances sometimes)
dtype_plan = model._get_dtype_plan(dtype)

# Obtain the weight conversion mapping for this model if any are registered
weight_conversions = get_model_conversion_mapping(model, key_mapping, hf_quantizer)
# Obtain the weight conversion mapping for this model if any are registered and apply to all submodels recursively
weight_conversions = model.get_weight_conversions_recursively(key_mapping, hf_quantizer)

if _torch_distributed_available and device_mesh is not None: # add hooks to nn.Modules: no weights
model = distribute_model(model, tp_plan, distributed_config, device_mesh, tp_size)
Expand Down Expand Up @@ -4309,6 +4309,20 @@ def _finalize_model_loading(

return loading_info

def get_weight_conversions_recursively(self, key_mapping=None, hf_quantizer=None, add_legacy=True):
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
conversions = []
conversions.extend(get_model_conversion_mapping(self, key_mapping, hf_quantizer, add_legacy))

for submodule in self.children():
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated
if (
submodule is not self
and isinstance(submodule, PreTrainedModel)
and submodule.config.__class__ != self.config.__class__
):
conversions.extend(get_model_conversion_mapping(submodule, key_mapping, hf_quantizer, add_legacy))
conversions.extend(submodule.get_weight_conversions_recursively(key_mapping, hf_quantizer, add_legacy))
return conversions
Comment thread
zucchini-nlp marked this conversation as resolved.
Outdated

def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = {".".join(key.split(".")[:-1]) for key in names}

Expand Down
10 changes: 0 additions & 10 deletions src/transformers/models/aria/modeling_aria.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,10 +885,6 @@ class AriaModelOutputWithPast(BaseModelOutputWithPast):
"""
)
class AriaModel(AriaPreTrainedModel):
_checkpoint_conversion_mapping = {
r"^language_model.model": "language_model",
}

def __init__(self, config: AriaConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
Expand Down Expand Up @@ -1031,12 +1027,6 @@ def _create_patch_attention_mask(self, pixel_mask):
"""
)
class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
r"^language_model.model": "model.language_model",
r"^vision_tower": "model.vision_tower",
r"^multi_modal_projector": "model.multi_modal_projector",
r"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

def __init__(self, config: AriaConfig):
Expand Down
10 changes: 0 additions & 10 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,6 @@ class AyaVisionModelOutputWithPast(BaseModelOutputWithPast):
"""
)
class AyaVisionModel(AyaVisionPreTrainedModel):
_checkpoint_conversion_mapping = {
r"^language_model.model": "language_model",
}

def __init__(self, config: AyaVisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
Expand Down Expand Up @@ -302,12 +298,6 @@ def forward(
"""
)
class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {
r"^language_model.model": "model.language_model",
r"^vision_tower": "model.vision_tower",
r"^multi_modal_projector": "model.multi_modal_projector",
r"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

def __init__(self, config: AyaVisionConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,6 @@ class Cohere2VisionPreTrainedModel(PreTrainedModel):
"""
)
class Cohere2VisionModel(Cohere2VisionPreTrainedModel):
_checkpoint_conversion_mapping = {}

def __init__(self, config: Cohere2VisionConfig):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config.vision_config)
Expand Down Expand Up @@ -254,7 +252,6 @@ def forward(
"""
)
class Cohere2VisionForConditionalGeneration(Cohere2VisionPreTrainedModel, GenerationMixin):
_checkpoint_conversion_mapping = {}
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}

def __init__(self, config: Cohere2VisionConfig):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ class Cohere2VisionPreTrainedModel(AyaVisionPreTrainedModel):


class Cohere2VisionModel(AyaVisionModel):
_checkpoint_conversion_mapping = {}

@can_return_tuple
@auto_docstring(
custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
Expand Down Expand Up @@ -154,8 +152,6 @@ def forward(


class Cohere2VisionForConditionalGeneration(AyaVisionForConditionalGeneration):
_checkpoint_conversion_mapping = {}

@auto_docstring
def get_image_features(
self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ class ColModernVBertForRetrievalOutput(ModelOutput):
"""
)
class ColModernVBertForRetrieval(ColModernVBertPreTrainedModel):
_checkpoint_conversion_mapping = {}

def __init__(self, config: ColModernVBertConfig):
super().__init__(config)
self.config = config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ class ColModernVBertForRetrievalOutput(ModelOutput):
"""
)
class ColModernVBertForRetrieval(ColPaliForRetrieval):
_checkpoint_conversion_mapping = {}

def __init__(self, config: ColModernVBertConfig):
super().__init__(config)
self.vlm = AutoModel.from_config(config.vlm_config)
Expand Down
Loading
Loading