diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 83d3790fc2fb..159c503f0f56 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -70,11 +70,56 @@ def _build_checkpoint_conversion_mapping(): operations=[MergeModulelist(dim=0), Concatenate(dim=1)], ), WeightConverter( - source_patterns=["mlp.experts.*.down_proj.weight"], + source_patterns="mlp.experts.*.down_proj.weight", target_patterns="mlp.experts.down_proj", operations=[MergeModulelist(dim=0)], ), ], + "phimoe": [ + WeightConverter( + source_patterns=[ + "mlp.experts.*.w1.weight", + "mlp.experts.*.w3.weight", + ], + target_patterns="mlp.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="mlp.experts.*.w2.weight", + target_patterns="mlp.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], + "lfm2_moe": [ + WeightConverter( + source_patterns=[ + "feed_forward.experts.*.w1.weight", + "feed_forward.experts.*.w3.weight", + ], + target_patterns="feed_forward.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="feed_forward.experts.*.w2.weight", + target_patterns="feed_forward.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], + "jamba": [ + WeightConverter( + source_patterns=[ + "feed_forward.experts.*.gate_proj.weight", + "feed_forward.experts.*.up_proj.weight", + ], + target_patterns="feed_forward.experts.gate_up_proj", + operations=[MergeModulelist(dim=0), Concatenate(dim=1)], + ), + WeightConverter( + source_patterns="feed_forward.experts.*.down_proj.weight", + target_patterns="feed_forward.experts.down_proj", + operations=[MergeModulelist(dim=0)], + ), + ], "timm_wrapper": [ # Simply add the prefix `timm_model` # TODO: Would be probably much cleaner with a `add_prefix` argument in WeightRenaming @@ -117,16 +162,13 @@ def _build_checkpoint_conversion_mapping(): ), ] - mapping["phimoe"] = mapping["mixtral"].copy() mapping["deepseek_v2"] = mapping["qwen2_moe"].copy() mapping["deepseek_v3"] = mapping["qwen2_moe"].copy() - mapping["dot1"] = mapping["qwen2_moe"].copy() - mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy() + mapping["dots1"] = mapping["qwen2_moe"].copy() + mapping["ernie4_5_moe"] = mapping["qwen2_moe"].copy() mapping["glm4_moe"] = mapping["qwen2_moe"].copy() mapping["glm4v_moe"] = mapping["qwen2_moe"].copy() - mapping["jamba"] = mapping["qwen2_moe"].copy() - mapping["lfm2_moe"] = mapping["mixtral"].copy() - mapping["long_cat_flash"] = mapping["qwen2_moe"].copy() + mapping["longcat_flash"] = mapping["qwen2_moe"].copy() mapping["qwen3_moe"] = mapping["qwen2_moe"].copy() mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy() mapping["qwen3_next"] = mapping["qwen2_moe"].copy() diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index f4ec3a7f38ca..8b0b9e92d898 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -48,27 +48,6 @@ logger = logging.get_logger(__name__) -def compile_glob_rule(source_glob: str, target_glob: str) -> tuple[re.Pattern, str]: - """ - Convert a glob-style source + target into a full regex + replacement. - - Rules: - - '*' in source_glob → (.*) capture group - - '*' in target_glob → \\1, \\2, ... 
backrefs - """ - regex = re.compile(source_glob) - - counter = 0 - - def _star_to_backref(_: re.Match) -> str: - nonlocal counter - counter += 1 - return rf"\{counter}" - - replacement = re.sub(r"\*", _star_to_backref, target_glob) - return regex, replacement - - def build_glob_alternation( globs: list[Union[WeightRenaming, WeightConverter, str]], ) -> tuple[re.Pattern, dict[str, str], dict[str, str]]: @@ -300,6 +279,7 @@ def convert( class WeightTransform: source_patterns: Union[str, list[str]] = field(init=True) target_patterns: Union[str, list[str]] = field(init=True) + compiled_sources: re.Pattern = field(init=False) distributed_operation: Optional[TensorParallelLayer] = None quantization_operation: Optional[ConversionOps] = None @@ -319,20 +299,27 @@ def __post_init__(self): for i, pattern in enumerate(self.target_patterns): # Some mapping contains `^` to notify start of string when matching -> remove it during reverse mapping pattern = pattern.removeprefix("^") - # This is ugly but needed for reverse mapping of Qwen2.5! - if r"(?!\.(language_model|visual))" in pattern: - pattern = pattern.replace(r"(?!\.(language_model|visual))", "") - # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper) + # Remove negative lookahead if any. This is ugly but needed for reverse mapping of Qwen2.5 and Sam3! + pattern = re.sub(r"\(\?!.+\)", "", pattern) + # Allow capturing groups in patterns, i.e. to add/remove a prefix to all keys (e.g. timm_wrapper, sam3) if r"(.+)" in pattern: - pattern = pattern.replace(r"(.+)", "") + pattern = pattern.replace(r"(.+)", r"\1") self.target_patterns[i] = pattern - # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper) + # We also need to check capturing groups in the sources during reverse mapping (e.g. timm_wrapper, sam3) for i, pattern in enumerate(self.source_patterns): if r"\1" in pattern: - pattern = pattern.replace(r"\1", "") + pattern = pattern.replace(r"\1", r"(.+)") self.source_patterns[i] = pattern + # Construct the regex we will use to rename keys from the sources to the targets + branches = [] + for i, source_pattern in enumerate(self.source_patterns): + group_name = f"g{i}" + pattern = source_pattern.replace(".*.", r"\..*\.") + branches.append(f"(?P<{group_name}>{pattern})") + self.compiled_sources = re.compile("|".join(branches)) + def add_tensor(self, target_key: str, source_key: str, source_pattern: str, future: Future): self.collected_tensors[source_pattern].append(future) self.layer_targets[target_key].add(source_key) @@ -341,6 +328,32 @@ def reset(self) -> None: """Clean-up the collected tensors to make sure we don't keep references to past tensors in memory.""" self.collected_tensors = defaultdict(list) + def rename_source_key(self, source_key: str) -> tuple[str, str | None]: + """ + Return a tuple (renamed_key, source_pattern_producing_the_match). + Try renaming `source_key` according to the source and target patterns of the current WeightTransform. + In case of a one-to-many transform, i.e. we have several target patterns, the matching source pattern + will be replaced by the first of all the target patterns (they are then correctly expanded in the Operations). 
+ """ + # Try matching one of the alternation branches + match_object = self.compiled_sources.search(source_key) + if match_object is None: + return source_key, None + # Find the source that produced the match (it's the first group that matched, as the search stops after first branch match) + matching_group_name = next(name for name, val in match_object.groupdict().items() if val is not None) + source_pattern_that_matched = self.source_patterns[int(matching_group_name[1:])] + # If we matched, we always replace with the first target pattern, in case we have several (one to many transform) + replacement = self.target_patterns[0] + # # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper, sam3) + if r"\1" in replacement: + # The index of the internal group we need to replace is the index of the matched named group as it comes + # inside that matched named group + replaced_group_idx = self.compiled_sources.groupindex[matching_group_name] + 1 + replacement = replacement.replace(r"\1", match_object.group(replaced_group_idx)) + renamed_key = source_key.replace(match_object.group(0), replacement) + + return renamed_key, source_pattern_that_matched + def reverse_transform(self) -> WeightTransform: """Reverse the current `WeightTransform` instance, to be able to save with the opposite weight transformations.""" # TODO: check this and relax when quantizer have `reverse_op` @@ -610,54 +623,30 @@ class SkipLayer(Exception): pass -def repl(m, repl_map: dict[str, str]) -> str: - # Collect all groups that matched - matched_groups = [name for name, val in m.groupdict().items() if val] - - if len(matched_groups) == 0: - # Should never happen - return m.group(0) - - if len(matched_groups) > 1: - raise ValueError( - "only a single match should happen, your regex patterns are tangled: " - f"groups matched = {matched_groups} for the patternsL {repl_map.keys()}" - ) - - # Exactly one match => return replacement - name = matched_groups[0] - replacement = repl_map[name] - # Allow capturing groups in patterns, i.e. to add a prefix to all keys (e.g. timm_wrapper) - if r"\1" in replacement and len(m.groups()) > 1: - replacement = replacement.replace(r"\1", m.group(1)) - - return replacement - - def rename_source_key( source_key: str, - rename_alternation: re.Pattern, - rename_by_group: dict, - weight_pattern_alternation: re.Pattern | None, - weight_pattern_by_group: dict | None, + weight_renamings: list[WeightRenaming], + weight_converters: list[WeightConverter], prefix: str | None = None, meta_state_dict: dict | None = None, -) -> tuple[str, re.Match | None]: +) -> tuple[str, str | None]: """ Rename a source key given all the renaming and weight conversion patterns we have. Also takes care of adding/removing the base model prefix during loading if necesary. """ - # 1. apply all renamings - renamed_key = rename_alternation.sub(lambda m: repl(m, rename_by_group), source_key).replace("\\", "") - - # 2. apply renaming through weight conversions on the key if we have any WeightConverter - matched_converter_pattern = ( - weight_pattern_alternation.search(renamed_key) if weight_pattern_alternation is not None else None - ) - if matched_converter_pattern is not None: - renamed_key = weight_pattern_alternation.sub(lambda m: repl(m, weight_pattern_by_group), renamed_key).replace( - "\\", "" - ) + renamed_key = source_key + # 1. 
apply all renamings in turns (if multiple match, it's the responsibility of the mappings to make sure they + # are coherent) + for renaming in weight_renamings: + renamed_key, _ = renaming.rename_source_key(renamed_key) + + # 2. apply renaming through weight conversions on the key if we have any WeightConverter (here we stop after + # the first match, as we assume only 1 converter can match any source key) + source_pattern = None + for converter in weight_converters: + renamed_key, source_pattern = converter.rename_source_key(renamed_key) + if source_pattern is not None: + break # 3. check if we need to add or remove prefix if necesary (only during loading, not saving) if prefix is not None and meta_state_dict is not None: @@ -669,7 +658,7 @@ def rename_source_key( elif meta_state_dict.get(f"{prefix}.{renamed_key}") is not None: renamed_key = f"{prefix}.{renamed_key}" - return renamed_key, matched_converter_pattern + return renamed_key, source_pattern def convert_and_load_state_dict_in_model( @@ -796,10 +785,6 @@ def convert_and_load_state_dict_in_model( # build '(?P.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'} # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched. - rename_alt, _, rename_by_group = build_glob_alternation(renamings) - weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None - if converters != []: - weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters) if tp_plan != {}: tp_plan_alt, tp_plan_by_group_name, _ = build_glob_alternation(list(tp_plan.keys())) if dtype_plan != {}: @@ -810,24 +795,19 @@ def convert_and_load_state_dict_in_model( state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: # 1. Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, matched_pattern = rename_source_key( - original_key, - rename_alt, - rename_by_group, - weight_pattern_alt, - tgt_group_to_glob, - prefix, - meta_model_state_dict, + renamed_key, source_pattern = rename_source_key( + original_key, renamings, converters, prefix, meta_model_state_dict ) # 2. 
finally, collect the tensor into the proper converter if renamed_key in missing_keys: empty_param = meta_model_state_dict.get(renamed_key) - if matched_pattern: - new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]) + # If we enter here, we have a WeightConverter operation to perform + if source_pattern is not None: + new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance mapping = param_name_to_load.setdefault(renamed_key, new_converter) - source_pattern = src_group_to_glob[matched_pattern.lastgroup] + # Otherwise, only potential renaming else: mapping = param_name_to_load.setdefault(renamed_key, WeightRenaming(original_key, renamed_key)) source_pattern = original_key @@ -879,8 +859,8 @@ def convert_and_load_state_dict_in_model( future = spawn_materialize(thread_pool, tensor, param_device, _dtype) mapping.add_tensor(renamed_key, original_key, source_pattern, future) - elif matched_pattern: # add all target keys as unexpected - mapping = pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]] + elif source_pattern is not None: # add all target keys as unexpected + mapping = pattern_to_converter[source_pattern] for k in mapping.target_patterns: unexpected_keys.add(renamed_key.replace(mapping.target_patterns[0], k)) else: @@ -961,24 +941,14 @@ def revert_weight_conversion(model: PreTrainedModel, state_dict: dict[str, torch pattern_to_converter = {k: converter for converter in converters for k in converter.source_patterns} conversion_mapping = {} - # build '(?P.*.*\\.block_sparse_moe\\..*)' and group to source {'g0': '*.block_sparse_moe.'} - # and target to source {'g0': '*.mlp.'}. This allows us to quickly find which pattern matched. - rename_alt, _, rename_by_group = build_glob_alternation(renamings) - weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = None, None, None - if converters != []: - weight_pattern_alt, src_group_to_glob, tgt_group_to_glob = build_glob_alternation(converters) - state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0])) for original_key, tensor in state_dict: # Rename the key according to all renaming pattern and optional weight converter patterns - renamed_key, matched_pattern = rename_source_key( - original_key, rename_alt, rename_by_group, weight_pattern_alt, tgt_group_to_glob - ) - if matched_pattern is not None: - new_converter = deepcopy(pattern_to_converter[src_group_to_glob[matched_pattern.lastgroup]]) + renamed_key, source_pattern = rename_source_key(original_key, renamings, converters) + if source_pattern is not None: + new_converter = deepcopy(pattern_to_converter[source_pattern]) # each target key gets its own converter instance mapping = conversion_mapping.setdefault(renamed_key, new_converter) - source_pattern = src_group_to_glob[matched_pattern.lastgroup] else: mapping = conversion_mapping.setdefault(renamed_key, WeightRenaming(original_key, renamed_key)) source_pattern = original_key diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 570577fca823..70142f2bf296 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -480,18 +480,15 @@ def accelerate_disk_offload( renamed) will be mapped to where they already reside on disk. Otherwise, the parameters will be resaved inside `disk_offload_folder` during loading. 
""" - from ..core_model_loading import WeightRenaming, build_glob_alternation, repl + from ..core_model_loading import WeightRenaming, rename_source_key if disk_offload_folder is not None: os.makedirs(disk_offload_folder, exist_ok=True) is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors") - rename = False + renamings = [] if weight_mapping is not None: renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)] - if len(renamings) > 0: - rename = True - rename_alt, _, rename_by_group = build_glob_alternation(renamings) # In this case, the offload index is simply the existing safetensors (except if using custom weight loading # Operation, e.g. the MoE models, where we need to resave the weights that were changed at loading time) @@ -505,10 +502,7 @@ def accelerate_disk_offload( weight_map = {k: os.path.join(folder, v) for k, v in sharded_metadata["weight_map"].items()} # Update the weight names according to the `weight_mapping` - weight_renaming_map = { - rename_alt.sub(lambda m: repl(m, rename_by_group), k).replace("\\", "") if rename else k: k - for k in weight_map - } + weight_renaming_map = {rename_source_key(k, renamings, [])[0]: k for k in weight_map} # Prepare the index using existing safetensors files disk_offload_index = { diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 998af1491cff..4c1c12a4058e 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -17,7 +17,7 @@ import os from typing import Any, Literal -from ..core_model_loading import WeightRenaming, build_glob_alternation, repl +from ..core_model_loading import WeightRenaming, rename_source_key from ..utils import ( CONFIG_NAME, cached_file, @@ -294,6 +294,9 @@ def load_adapter( adapter_state_dict = load_peft_weights(peft_model_id, token=token, device=device, **adapter_kwargs) # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility + renamings = [] + if key_mapping: + renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)] processed_adapter_state_dict = {} prefix = "base_model.model." for key, value in adapter_state_dict.items(): @@ -302,10 +305,7 @@ def load_adapter( else: new_key = key - if key_mapping: - renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)] - rename_alt, _, rename_by_group = build_glob_alternation(renamings) - new_key = rename_alt.sub(lambda m: repl(m, rename_by_group), new_key).replace("\\", "") + new_key = rename_source_key(new_key, renamings, [])[0] # For hotswapping, we need the adapter name to be present in the state dict keys if hotswap: diff --git a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py index cea82636927b..d9deffefdc82 100644 --- a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py @@ -54,6 +54,8 @@ class Lfm2MoeConfig(PreTrainedConfig): with longer `max_position_embeddings`. max_position_embeddings (`int`, *optional*, defaults to 128000): The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if `config.is_decoder=True`. @@ -117,6 +119,7 @@ def __init__( tie_word_embeddings: bool = True, rope_parameters: RopeParameters = None, max_position_embeddings: int = 128_000, + initializer_range: float = 0.02, use_cache: bool = True, norm_eps: float = 0.00001, num_attention_heads: int = 32, @@ -140,6 +143,7 @@ def __init__( rope_scaling = kwargs.pop("rope_scaling", None) self.rope_parameters = rope_scaling or rope_parameters self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range self.use_cache = use_cache self.norm_eps = norm_eps diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 73b9c4a8fde0..8df14fe005ef 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -25,7 +25,6 @@ from torch import nn from ... import initialization as init -from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub @@ -156,7 +155,6 @@ def __init__(self, config): self.intermediate_dim = config.moe_intermediate_size self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) - self.act_fn = ACT2FN[config.hidden_act] def forward( self, @@ -165,22 +163,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) - current_hidden_states = self.act_fn(gate) * up + current_hidden_states = F.silu(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py index 1ada897f00db..f4549b2e6aa2 100644 --- a/src/transformers/models/lfm2_moe/modular_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modular_lfm2_moe.py @@ -14,6 +14,7 @@ from typing import Optional import torch +import torch.nn.functional as F from torch import nn from ... 
import initialization as init @@ -69,7 +70,35 @@ def __init__(self, config: Lfm2MoeConfig, intermediate_size: Optional[int] = Non class Lfm2MoeExperts(Qwen2MoeExperts): - pass + def __init__(self, config): + super().__init__(config) + del self.act_fn + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = F.silu(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states class Lfm2MoeSparseMoeBlock(nn.Module): diff --git a/src/transformers/models/sam3/modeling_sam3.py b/src/transformers/models/sam3/modeling_sam3.py index de12b057aaa7..5a407efe58bb 100644 --- a/src/transformers/models/sam3/modeling_sam3.py +++ b/src/transformers/models/sam3/modeling_sam3.py @@ -2089,7 +2089,9 @@ def _embed_pixels( class Sam3Model(Sam3PreTrainedModel): input_modalities = ["image", "text"] - _checkpoint_conversion_mapping = {"detector_model.": ""} + _checkpoint_conversion_mapping = { + r"detector_model.(.+)": r"\1" # the regex allows to remove the prefix, and add it back in revert mode + } _keys_to_ignore_on_load_unexpected = [ r"^tracker_model.", r"^tracker_neck.", diff --git a/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py b/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py index 829fe01b9436..8a1a77eaafd9 100644 --- a/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py +++ b/src/transformers/models/sam3_tracker/modeling_sam3_tracker.py @@ -768,7 +768,7 @@ class Sam3TrackerModel(Sam3TrackerPreTrainedModel): "occlusion_spatial_embedding_parameter", ] _checkpoint_conversion_mapping = { - "tracker_model.": "", + r"tracker_model.(.+)": r"\1", # the regex allows to remove the prefix, and add it back in revert mode "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.", "tracker_neck.": "vision_encoder.neck.", } diff --git a/src/transformers/models/sam3_tracker/modular_sam3_tracker.py b/src/transformers/models/sam3_tracker/modular_sam3_tracker.py index dee826adef91..18937f5455f6 100644 --- a/src/transformers/models/sam3_tracker/modular_sam3_tracker.py +++ b/src/transformers/models/sam3_tracker/modular_sam3_tracker.py @@ -180,7 +180,7 @@ class Sam3TrackerMaskDecoder(Sam2MaskDecoder): class Sam3TrackerModel(Sam2Model): _checkpoint_conversion_mapping = { - "tracker_model.": "", + r"tracker_model.(.+)": r"\1", # the regex allows to remove the prefix, and add it back in revert mode "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.", "tracker_neck.": "vision_encoder.neck.", } diff --git a/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py 
b/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py index d31ff21f80d3..125b86bacea0 100644 --- a/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +++ b/src/transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py @@ -1570,7 +1570,7 @@ class Sam3TrackerVideoModel(Sam3TrackerVideoPreTrainedModel): _tied_weights_keys = {} _keys_to_ignore_on_load_missing = [] _checkpoint_conversion_mapping = { - "tracker_model.": "", + r"tracker_model.(.+)": r"\1", # the regex allows to remove the prefix, and add it back in revert mode "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.", "tracker_neck.": "vision_encoder.neck.", } diff --git a/src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py b/src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py index cf8fdcdf303a..12e4bc8854e0 100644 --- a/src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +++ b/src/transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py @@ -456,7 +456,7 @@ class Sam3TrackerVideoMaskDecoder(Sam2VideoMaskDecoder): class Sam3TrackerVideoModel(Sam2VideoModel): _checkpoint_conversion_mapping = { - "tracker_model.": "", + r"tracker_model.(.+)": r"\1", # the regex allows to remove the prefix, and add it back in revert mode "detector_model.vision_encoder.backbone.": "vision_encoder.backbone.", "tracker_neck.": "vision_encoder.neck.", } diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c1f7c1b83ed0..d2b5e0949cac 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4058,7 +4058,6 @@ def test_tp_plan_matches_params(self): len(unused_entries) == 0, f"The following entries of the TP-plan are not valid: {unused_entries}" ) - @unittest.skip("Some models have wrong mappings....") def test_reverse_loading_mapping(self): """Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion mapping. @@ -4073,7 +4072,7 @@ def test_reverse_loading_mapping(self): # Some MoE models alternate between a classic MLP and a MoE layer, in which case we want to have at # lest one MoE layer here to check the mapping - config_to_set = config.get_text_config() + config_to_set = config.get_text_config(decoder=True) config_to_set.first_k_dense_replace = 1 # means that the first layer (idx 0) will be MLP, then MoE config_to_set.moe_layer_start_index = 1 # same as above but for Ernie 4.5... config_to_set.mlp_only_layers = [0] # same but for qwens @@ -4137,7 +4136,6 @@ def test_reverse_loading_mapping(self): # Make sure both saved state_dict are identical self.assertTrue(compare_state_dicts(state_dict_saved_from_init, state_dict_saved_from_pretrained)) - @unittest.skip("Some models have wrong mappings....") def test_can_load_from_already_mapped_keys(self): """Test that we can correctly reload a model if we chose `save_original_format=False` in `save_pretrained`, i.e. we do not reapply weight conversions when reloading if it was saved correctly already. 
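For reference, not part of the patch: the new phimoe/lfm2_moe/jamba entries in conversion_mapping.py collect the per-expert w1 (gate), w3 (up) and w2 (down) checkpoint tensors into the fused gate_up_proj/down_proj parameters declared in Lfm2MoeExperts. Below is a minimal shape sketch of that merge, assuming MergeModulelist(dim=0) stacks each expert list along a new leading dimension and Concatenate(dim=1) fuses gate and up along dim 1 (tensor names and toy sizes are illustrative only).

import torch

# Toy sizes; the real models read these from the config.
num_experts, hidden_dim, intermediate_dim = 4, 16, 32

# Per-expert checkpoint tensors: w1 (gate) and w3 (up) are [intermediate, hidden],
# w2 (down) is [hidden, intermediate].
w1 = [torch.randn(intermediate_dim, hidden_dim) for _ in range(num_experts)]
w3 = [torch.randn(intermediate_dim, hidden_dim) for _ in range(num_experts)]
w2 = [torch.randn(hidden_dim, intermediate_dim) for _ in range(num_experts)]

# MergeModulelist(dim=0): stack every expert's tensor along a new leading dim.
w1_stacked = torch.stack(w1, dim=0)  # [num_experts, intermediate, hidden]
w3_stacked = torch.stack(w3, dim=0)  # [num_experts, intermediate, hidden]

# Concatenate(dim=1): fuse gate and up into a single gate_up_proj tensor.
gate_up_proj = torch.cat([w1_stacked, w3_stacked], dim=1)  # [num_experts, 2*intermediate, hidden]
down_proj = torch.stack(w2, dim=0)                         # [num_experts, hidden, intermediate]

# Shapes match the parameters declared in Lfm2MoeExperts.__init__.
assert gate_up_proj.shape == (num_experts, 2 * intermediate_dim, hidden_dim)
assert down_proj.shape == (num_experts, hidden_dim, intermediate_dim)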
diff --git a/tests/utils/test_core_model_loading.py b/tests/utils/test_core_model_loading.py index 88bdb27256ba..6198702fe4a9 100644 --- a/tests/utils/test_core_model_loading.py +++ b/tests/utils/test_core_model_loading.py @@ -27,7 +27,7 @@ WeightRenaming, build_glob_alternation, convert_and_load_state_dict_in_model, - repl, + rename_source_key, revert_weight_conversion, ) from transformers.utils.import_utils import is_triton_available @@ -138,23 +138,24 @@ def test_sub_key_rewrites_targets(self): WeightRenaming("block_sparse_moe.experts.*.w2.weight", "mlp.experts.down_proj"), WeightRenaming("model.language_model.*", "language_model"), ] - rename_alt, _, rename_by_group = build_glob_alternation(renamings) - def rename(original_key: str) -> str: - return rename_alt.sub(lambda m: repl(m, rename_by_group), original_key).replace("\\", "") - - self.assertEqual(rename("foo.block_sparse_moe.experts.3.w1.weight"), "foo.mlp.experts.gate_up_proj") - self.assertEqual(rename("foo.block_sparse_moe.experts.3.w2.weight"), "foo.mlp.experts.down_proj") - self.assertEqual(rename("model.language_model.lm_head.weight"), "language_model") + self.assertEqual( + rename_source_key("foo.block_sparse_moe.experts.3.w1.weight", renamings, [])[0], + "foo.mlp.experts.gate_up_proj", + ) + self.assertEqual( + rename_source_key("foo.block_sparse_moe.experts.3.w2.weight", renamings, [])[0], + "foo.mlp.experts.down_proj", + ) + self.assertEqual(rename_source_key("model.language_model.lm_head.weight", renamings, [])[0], "language_model") def test_sub_key_no_match_returns_original(self): renamings = [ WeightRenaming("block_sparse_moe.experts.*.w1.weight", "*.mlp.experts.gate_up_proj"), ] - rename_alt, _, rename_by_group = build_glob_alternation(renamings) key = "unrelated.key" - renamed_key = rename_alt.sub(lambda m: repl(m, rename_by_group), key).replace("\\", "") + renamed_key, _ = rename_source_key(key, renamings, []) self.assertEqual(renamed_key, key)
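
For reference, not part of the patch: a standalone sketch of the renaming logic that the added WeightTransform.__post_init__ and WeightTransform.rename_source_key implement. The source patterns are compiled into one alternation of named branches, the first branch that matches identifies the originating source pattern, and the matched span is replaced by the first target pattern. The helper name rename_key and the example key are illustrative only; the real method additionally handles the \1 capturing-group case used to add or remove prefixes (timm_wrapper, sam3).

import re

def rename_key(source_key, source_patterns, target_patterns):
    # Compile the source patterns into one alternation of named branches,
    # escaping the dots around the '*' wildcard, as __post_init__ does.
    branches = []
    for i, pattern in enumerate(source_patterns):
        pattern = pattern.replace(".*.", r"\..*\.")
        branches.append(f"(?P<g{i}>{pattern})")
    compiled = re.compile("|".join(branches))

    match = compiled.search(source_key)
    if match is None:
        return source_key, None
    # The branch that matched tells us which source pattern produced the match.
    group_name = next(name for name, value in match.groupdict().items() if value is not None)
    matched_pattern = source_patterns[int(group_name[1:])]
    # One-to-many transforms always rename to the first target pattern; the
    # remaining targets are expanded later by the conversion Operations.
    renamed = source_key.replace(match.group(0), target_patterns[0])
    return renamed, matched_pattern

# Mixtral-style experts: both w1 and w3 collapse onto the fused gate_up_proj key,
# and the returned source pattern tells the loader which list to collect into.
print(rename_key(
    "model.layers.0.mlp.experts.3.w1.weight",
    ["mlp.experts.*.w1.weight", "mlp.experts.*.w3.weight"],
    ["mlp.experts.gate_up_proj"],
))
# ('model.layers.0.mlp.experts.gate_up_proj', 'mlp.experts.*.w1.weight')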