diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py
index 6f4781524c3b..d9db8aa2ef36 100644
--- a/src/transformers/conversion_mapping.py
+++ b/src/transformers/conversion_mapping.py
@@ -82,6 +82,21 @@ def _build_checkpoint_conversion_mapping():
                 operations=[MergeModulelist(dim=0)],
             ),
         ],
+        "qwen3_vl_moe": [
+            WeightConverter(
+                source_patterns=[
+                    "mlp.experts.*.gate_proj.weight",
+                    "mlp.experts.*.up_proj.weight",
+                ],
+                target_patterns="mlp.experts.gate_up_proj",
+                operations=[MergeModulelist(dim=0), Concatenate(dim=1), Transpose(1, 2)],
+            ),
+            WeightConverter(
+                source_patterns="mlp.experts.*.down_proj.weight",
+                target_patterns="mlp.experts.down_proj",
+                operations=[MergeModulelist(dim=0), Transpose(1, 2)],
+            ),
+        ],
         "phimoe": [
             WeightConverter(
                 source_patterns=[
@@ -234,7 +249,6 @@ def _build_checkpoint_conversion_mapping():
     mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
     mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
     mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
-    mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
     mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
     mapping["minimax"] = mapping["mixtral"].copy()
     mapping["minimax_m2"] = mapping["mixtral"].copy()
diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py
index 9e43baf498b1..92068f127368 100644
--- a/src/transformers/core_model_loading.py
+++ b/src/transformers/core_model_loading.py
@@ -113,12 +113,12 @@ def convert(
     ) -> dict[str, torch.Tensor]:
         tensors = next(iter(input_dict.values()))
         tensor = tensors[0] if isinstance(tensors, list) else tensors
-        targets = self.get_target_pattern(input_dict, target_patterns)
+        targets = self.get_target_patterns(input_dict, target_patterns)
         sizes = len(targets)
         chunks = torch.chunk(tensor, sizes, dim=self.dim)
         return dict(zip(targets, chunks))
 
-    def get_target_pattern(self, input_dict: dict, target_patterns: list[str]) -> list[str]:
+    def get_target_patterns(self, input_dict: dict, target_patterns: list[str]) -> list[str]:
         # Here we always return the target patterns
         if len(input_dict) > 1 or len(target_patterns) == 1:
             raise ValueError("Undefined Operation encountered!")
@@ -245,6 +245,44 @@ def reverse_op(self) -> ConversionOps:
         return MergeModulelist(self.dim)
 
 
+class Transpose(ConversionOps):
+    """
+    Transposes the given tensor along dim0 and dim1.
+ """ + + def __init__(self, dim0: int = 0, dim1: int = 1): + self.dim0 = dim0 + self.dim1 = dim1 + + @torch.no_grad + def convert( + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs + ) -> dict[str, torch.Tensor]: + target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns) + tensors = next(iter(input_dict.values())) + tensor = tensors[0] if isinstance(tensors, list) else tensors + return {target_pattern: torch.transpose(tensor, dim0=self.dim0, dim1=self.dim1).contiguous()} + + def get_target_pattern( + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str] + ) -> str: + if len(input_dict) != 1: + raise ValueError("Undefined Operation encountered!") + # Here it's the first operation of a chain, so return the source + if len(target_patterns) > 1: + if len(source_patterns) == 1: + return source_patterns[0] + else: + raise ValueError("Undefined Operation encountered!") + # Here it's the only operation, or the last operation in a chain, so we return the target + else: + return target_patterns[0] + + @property + def reverse_op(self) -> ConversionOps: + return Transpose(dim0=self.dim1, dim1=self.dim0) + + class PermuteForRope(ConversionOps): """ Applies the permutation required to convert complex RoPE weights to the split sin/cos format. @@ -402,43 +440,6 @@ def reverse_op(self) -> ConversionOps: return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim) -class Transpose(ConversionOps): - """ - Transposes the given tensor along dim0 and dim1. - """ - - def __init__(self, dim0: int = 0, dim1: int = 1): - self.dim0 = dim0 - self.dim1 = dim1 - - @torch.no_grad() - def convert( - self, - input_dict: dict[str, list[torch.Tensor]], - source_patterns: list[str], - target_patterns: list[str], - config, - **kwargs, - ) -> dict[str, list[torch.Tensor]]: - if len(input_dict) != len(target_patterns): - raise ValueError( - f"Transpose conversion can only happen on each key ({len(input_dict)}) " - f"and should match exact one target ({len(target_patterns)})." - ) - - output: dict[str, list[torch.Tensor]] = {} - for key, target_pattern in zip(input_dict.keys(), target_patterns): - tensor = input_dict.get(key, []) - if len(tensor) != 1: - raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(tensor)}.") - output[target_pattern] = torch.transpose(tensor[0], dim0=self.dim0, dim1=self.dim1).contiguous() - return output - - @property - def reverse_op(self) -> ConversionOps: - return Transpose(dim0=self.dim1, dim1=self.dim0) - - @dataclass(slots=True) class WeightTransform: source_patterns: str | list[str] = field(init=True) @@ -739,7 +740,7 @@ def dot_natural_key(s: str): @contextmanager def log_conversion_errors( first_target_key: str, - conversion_errors: MutableMapping[str, str], + conversion_errors: MutableMapping[str, str] | None, extras: Any = None, op: list[ConversionOps] | ConversionOps | None = None, ): @@ -748,6 +749,9 @@ def log_conversion_errors( try: yield except Exception as e: + # During reverse mapping, we do not log and skip errors + if conversion_errors is None: + raise e def _format_op_name(curr_op: list[ConversionOps] | ConversionOps | None) -> str | None: if curr_op is None: