Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ def _build_checkpoint_conversion_mapping():
operations=[MergeModulelist(dim=0)],
),
],
"qwen3_vl_moe": [
WeightConverter(
source_patterns=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_patterns="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1), Transpose(1, 2)],
),
WeightConverter(
source_patterns="mlp.experts.*.down_proj.weight",
target_patterns="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0), Transpose(1, 2)],
),
],
"phimoe": [
WeightConverter(
source_patterns=[
Expand Down Expand Up @@ -234,7 +249,6 @@ def _build_checkpoint_conversion_mapping():
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()
mapping["minimax_m2"] = mapping["mixtral"].copy()
Expand Down
84 changes: 44 additions & 40 deletions src/transformers/core_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def convert(
) -> dict[str, torch.Tensor]:
tensors = next(iter(input_dict.values()))
tensor = tensors[0] if isinstance(tensors, list) else tensors
targets = self.get_target_pattern(input_dict, target_patterns)
targets = self.get_target_patterns(input_dict, target_patterns)
sizes = len(targets)
chunks = torch.chunk(tensor, sizes, dim=self.dim)
return dict(zip(targets, chunks))

def get_target_pattern(self, input_dict: dict, target_patterns: list[str]) -> list[str]:
def get_target_patterns(self, input_dict: dict, target_patterns: list[str]) -> list[str]:
# Here we always return the target patterns
if len(input_dict) > 1 or len(target_patterns) == 1:
raise ValueError("Undefined Operation encountered!")
Expand Down Expand Up @@ -245,6 +245,44 @@ def reverse_op(self) -> ConversionOps:
return MergeModulelist(self.dim)


class Transpose(ConversionOps):
"""
Transposes the given tensor along dim0 and dim1.
"""

    def __init__(self, dim0: int = 0, dim1: int = 1):
        # The two axes swapped by this conversion op; defaults give a plain 2D transpose.
        self.dim0 = dim0
        self.dim1 = dim1
Comment on lines +253 to +255
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we pass the full converter to change the shard dim at init?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, but aren't you assuming that we need to change the shard dim based on transpose? At least the transposes that do use it are meant to align with other models, i.e. they will need to be sharded the same way as intended.

There might be a few models that shard differently tho, I would consider this not part of the conversion op tho - otherwise we will mix in our assumptions 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's due to the ops order — sharding happens before Transpose, so if you Transpose the dim that was sharded on, you've got an issue

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I messed up the order, I thought it was transpose then shard

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Life would be too nice... 😆


    @torch.no_grad
    def convert(
        self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Transpose the single collected tensor along ``dim0``/``dim1``.

        Returns a one-entry dict keyed by the pattern chosen by
        ``get_target_pattern`` (the source pattern when this op starts a chain,
        the target pattern otherwise).
        """
        target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns)
        # input_dict holds exactly one entry (enforced in get_target_pattern);
        # its value may be a bare tensor or a list, in which case the first element is used.
        tensors = next(iter(input_dict.values()))
        tensor = tensors[0] if isinstance(tensors, list) else tensors
        # .contiguous() so downstream consumers get a materialized layout, not a lazy view.
        return {target_pattern: torch.transpose(tensor, dim0=self.dim0, dim1=self.dim1).contiguous()}

    def get_target_pattern(
        self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str]
    ) -> str:
        """Pick the output key for the transposed tensor.

        Raises:
            ValueError: if the op was given more than one collected key, or if a
                multi-target chain does not originate from exactly one source.
        """
        if len(input_dict) != 1:
            raise ValueError("Undefined Operation encountered!")
        # Here it's the first operation of a chain, so return the source
        if len(target_patterns) > 1:
            if len(source_patterns) == 1:
                return source_patterns[0]
            else:
                raise ValueError("Undefined Operation encountered!")
        # Here it's the only operation, or the last operation in a chain, so we return the target
        else:
            return target_patterns[0]
Comment on lines +266 to +279
Copy link
Member

@IlyasMoutawwakil IlyasMoutawwakil Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it has to either be the first or the last in a chain of ops ? I can see the dequantization ops breaking this as they extend the chain 🥲 correct me if i'm wrong

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, it should be alright with quantization! They do not change the targets/sources!


    @property
    def reverse_op(self) -> ConversionOps:
        # Swapping the two axes back yields the inverse transpose.
        return Transpose(dim0=self.dim1, dim1=self.dim0)


class PermuteForRope(ConversionOps):
"""
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
Expand Down Expand Up @@ -402,43 +440,6 @@ def reverse_op(self) -> ConversionOps:
return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim)


class Transpose(ConversionOps):
    """
    Transposes the given tensor along dim0 and dim1.
    """

    def __init__(self, dim0: int = 0, dim1: int = 1):
        # Axes to swap; the defaults mirror a plain 2D transpose.
        self.dim0 = dim0
        self.dim1 = dim1

    @torch.no_grad()
    def convert(
        self,
        input_dict: dict[str, list[torch.Tensor]],
        source_patterns: list[str],
        target_patterns: list[str],
        config,
        **kwargs,
    ) -> dict[str, list[torch.Tensor]]:
        """Transpose each collected tensor, pairing the i-th input key with the i-th target pattern.

        Every key must carry exactly one tensor; the result maps each target
        pattern to the transposed (and materialized) tensor.
        """
        if len(input_dict) != len(target_patterns):
            raise ValueError(
                f"Transpose conversion can only happen on each key ({len(input_dict)}) "
                f"and should match exact one target ({len(target_patterns)})."
            )

        converted: dict[str, list[torch.Tensor]] = {}
        for (source_key, collected), target in zip(input_dict.items(), target_patterns):
            if len(collected) != 1:
                raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(collected)}.")
            # .contiguous() so downstream consumers see a materialized layout, not a view.
            converted[target] = torch.transpose(collected[0], dim0=self.dim0, dim1=self.dim1).contiguous()
        return converted

    @property
    def reverse_op(self) -> ConversionOps:
        # A transpose with the axes swapped undoes this one.
        return Transpose(dim0=self.dim1, dim1=self.dim0)


@dataclass(slots=True)
class WeightTransform:
source_patterns: str | list[str] = field(init=True)
Expand Down Expand Up @@ -739,7 +740,7 @@ def dot_natural_key(s: str):
@contextmanager
def log_conversion_errors(
first_target_key: str,
conversion_errors: MutableMapping[str, str],
conversion_errors: MutableMapping[str, str] | None,
extras: Any = None,
op: list[ConversionOps] | ConversionOps | None = None,
):
Expand All @@ -748,6 +749,9 @@ def log_conversion_errors(
try:
yield
except Exception as e:
# During reverse mapping, we do not log and skip errors
if conversion_errors is None:
raise e

def _format_op_name(curr_op: list[ConversionOps] | ConversionOps | None) -> str | None:
if curr_op is None:
Expand Down