-
Notifications
You must be signed in to change notification settings - Fork 32.5k
[loading] Fix Transpose Operation, and qwen3_vl_moe mapping #43307
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
873cf5b
001c604
093713e
8f51f29
73f13e2
306333b
9a937ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -113,12 +113,12 @@ def convert( | |
| ) -> dict[str, torch.Tensor]: | ||
| tensors = next(iter(input_dict.values())) | ||
| tensor = tensors[0] if isinstance(tensors, list) else tensors | ||
| targets = self.get_target_pattern(input_dict, target_patterns) | ||
| targets = self.get_target_patterns(input_dict, target_patterns) | ||
| sizes = len(targets) | ||
| chunks = torch.chunk(tensor, sizes, dim=self.dim) | ||
| return dict(zip(targets, chunks)) | ||
|
|
||
| def get_target_pattern(self, input_dict: dict, target_patterns: list[str]) -> list[str]: | ||
| def get_target_patterns(self, input_dict: dict, target_patterns: list[str]) -> list[str]: | ||
| # Here we always return the target patterns | ||
| if len(input_dict) > 1 or len(target_patterns) == 1: | ||
| raise ValueError("Undefined Operation encountered!") | ||
|
|
@@ -245,6 +245,44 @@ def reverse_op(self) -> ConversionOps: | |
| return MergeModulelist(self.dim) | ||
|
|
||
|
|
||
| class Transpose(ConversionOps): | ||
| """ | ||
| Transposes the given tensor along dim0 and dim1. | ||
| """ | ||
|
|
||
| def __init__(self, dim0: int = 0, dim1: int = 1): | ||
| self.dim0 = dim0 | ||
| self.dim1 = dim1 | ||
|
|
||
| @torch.no_grad | ||
| def convert( | ||
| self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs | ||
| ) -> dict[str, torch.Tensor]: | ||
| target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns) | ||
| tensors = next(iter(input_dict.values())) | ||
| tensor = tensors[0] if isinstance(tensors, list) else tensors | ||
| return {target_pattern: torch.transpose(tensor, dim0=self.dim0, dim1=self.dim1).contiguous()} | ||
|
|
||
| def get_target_pattern( | ||
| self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str] | ||
| ) -> str: | ||
| if len(input_dict) != 1: | ||
| raise ValueError("Undefined Operation encountered!") | ||
| # Here it's the first operation of a chain, so return the source | ||
| if len(target_patterns) > 1: | ||
| if len(source_patterns) == 1: | ||
| return source_patterns[0] | ||
| else: | ||
| raise ValueError("Undefined Operation encountered!") | ||
| # Here it's the only operation, or the last operation in a chain, so we return the target | ||
| else: | ||
| return target_patterns[0] | ||
|
Comment on lines
+266
to
+279
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So it has to either be the first or the last in a chain of ops? I can see the dequantization ops breaking this, as they extend the chain 🥲 — correct me if I'm wrong
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nope, it should be alright with quantization! They do not change the targets/sources! |
||
|
|
||
| @property | ||
| def reverse_op(self) -> ConversionOps: | ||
| return Transpose(dim0=self.dim1, dim1=self.dim0) | ||
|
|
||
|
|
||
| class PermuteForRope(ConversionOps): | ||
| """ | ||
| Applies the permutation required to convert complex RoPE weights to the split sin/cos format. | ||
|
|
@@ -402,43 +440,6 @@ def reverse_op(self) -> ConversionOps: | |
| return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim) | ||
|
|
||
|
|
||
| class Transpose(ConversionOps): | ||
| """ | ||
| Transposes the given tensor along dim0 and dim1. | ||
| """ | ||
|
|
||
| def __init__(self, dim0: int = 0, dim1: int = 1): | ||
| self.dim0 = dim0 | ||
| self.dim1 = dim1 | ||
|
|
||
| @torch.no_grad() | ||
| def convert( | ||
| self, | ||
| input_dict: dict[str, list[torch.Tensor]], | ||
| source_patterns: list[str], | ||
| target_patterns: list[str], | ||
| config, | ||
| **kwargs, | ||
| ) -> dict[str, list[torch.Tensor]]: | ||
| if len(input_dict) != len(target_patterns): | ||
| raise ValueError( | ||
| f"Transpose conversion can only happen on each key ({len(input_dict)}) " | ||
| f"and should match exact one target ({len(target_patterns)})." | ||
| ) | ||
|
|
||
| output: dict[str, list[torch.Tensor]] = {} | ||
| for key, target_pattern in zip(input_dict.keys(), target_patterns): | ||
| tensor = input_dict.get(key, []) | ||
| if len(tensor) != 1: | ||
| raise ValueError(f"Transpose conversion requires exactly one tensor, found {len(tensor)}.") | ||
| output[target_pattern] = torch.transpose(tensor[0], dim0=self.dim0, dim1=self.dim1).contiguous() | ||
| return output | ||
|
|
||
| @property | ||
| def reverse_op(self) -> ConversionOps: | ||
| return Transpose(dim0=self.dim1, dim1=self.dim0) | ||
|
|
||
|
|
||
| @dataclass(slots=True) | ||
| class WeightTransform: | ||
| source_patterns: str | list[str] = field(init=True) | ||
|
|
@@ -739,7 +740,7 @@ def dot_natural_key(s: str): | |
| @contextmanager | ||
| def log_conversion_errors( | ||
| first_target_key: str, | ||
| conversion_errors: MutableMapping[str, str], | ||
| conversion_errors: MutableMapping[str, str] | None, | ||
| extras: Any = None, | ||
| op: list[ConversionOps] | ConversionOps | None = None, | ||
| ): | ||
|
|
@@ -748,6 +749,9 @@ def log_conversion_errors( | |
| try: | ||
| yield | ||
| except Exception as e: | ||
| # During reverse mapping, we do not log and skip errors | ||
| if conversion_errors is None: | ||
| raise e | ||
|
|
||
| def _format_op_name(curr_op: list[ConversionOps] | ConversionOps | None) -> str | None: | ||
| if curr_op is None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we pass the full converter to change the shard dim at init?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, but aren't you assuming that we need to change the shard dim based on transpose? At least the transposes that do use it make it so that they are aligned with other models, i.e. they will need to be sharded the same way as intended. There might be a few models that shard differently, though; I would consider that not part of the conversion op — otherwise we will mix in our assumptions 🤔
There might be a few models that shard differently tho, I would consider this not part of the conversion op tho - otherwise we will mix in our assumptions 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's due to the order of ops — sharding happens before Transpose, so if you transpose the dim that was sharded on, you've got an issue
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I messed up the order, I thought it was transpose then shard
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Life would be too nice... 😆