-
Notifications
You must be signed in to change notification settings - Fork 267
Refactor and consolidate moe lora extractors #629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -334,6 +334,163 @@ def _has_lora_adapters(param) -> bool: | |||||||||||||||||||||||||||||||
| return len(param.lora_A) > 0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _canonical_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A: torch.Tensor, | ||||||||||||||||||||||||||||||||
| weight_B: torch.Tensor, | ||||||||||||||||||||||||||||||||
| num_experts: int, | ||||||||||||||||||||||||||||||||
| rank_per_expert: int, | ||||||||||||||||||||||||||||||||
| dim_A: int, | ||||||||||||||||||||||||||||||||
| dim_B: int, | ||||||||||||||||||||||||||||||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||
| first_weight = weight_A.view(num_experts, rank_per_expert, dim_A) | ||||||||||||||||||||||||||||||||
| first_weight = first_weight.permute(0, 2, 1).contiguous() | ||||||||||||||||||||||||||||||||
| second_weight = weight_B.view(dim_B, num_experts, rank_per_expert) | ||||||||||||||||||||||||||||||||
| second_weight = second_weight.permute(1, 2, 0).contiguous() | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _reversed_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A: torch.Tensor, | ||||||||||||||||||||||||||||||||
| weight_B: torch.Tensor, | ||||||||||||||||||||||||||||||||
| num_experts: int, | ||||||||||||||||||||||||||||||||
| rank_per_expert: int, | ||||||||||||||||||||||||||||||||
| dim_A: int, | ||||||||||||||||||||||||||||||||
| dim_B: int, | ||||||||||||||||||||||||||||||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||||||||||||||||||||||||||||||||
| first_weight = weight_B.view(dim_B, num_experts, rank_per_expert) | ||||||||||||||||||||||||||||||||
| first_weight = first_weight.permute(1, 0, 2).contiguous() | ||||||||||||||||||||||||||||||||
| second_weight = weight_A.view(num_experts, rank_per_expert, dim_A).contiguous() | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _get_param_shape_from_module(module, parameter_name): | ||||||||||||||||||||||||||||||||
| if module is None or parameter_name is None or not hasattr(module, parameter_name): | ||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||
| param = getattr(module, parameter_name) | ||||||||||||||||||||||||||||||||
| if hasattr(param, "get_param"): | ||||||||||||||||||||||||||||||||
| param = param.get_param() | ||||||||||||||||||||||||||||||||
| elif hasattr(param, "weight"): | ||||||||||||||||||||||||||||||||
| param = param.weight | ||||||||||||||||||||||||||||||||
| return tuple(param.shape) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _get_moe_lora_io_dims(wrapper, experts_module=None): | ||||||||||||||||||||||||||||||||
| base = None | ||||||||||||||||||||||||||||||||
| if wrapper is not None and hasattr(wrapper, "get_base_layer"): | ||||||||||||||||||||||||||||||||
| base = wrapper.get_base_layer() | ||||||||||||||||||||||||||||||||
| if experts_module is None: | ||||||||||||||||||||||||||||||||
| experts_module = base | ||||||||||||||||||||||||||||||||
| if experts_module is None: | ||||||||||||||||||||||||||||||||
| experts_module = getattr(wrapper, "base_layer", None) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| parameter_name = getattr(wrapper, "parameter_name", None) | ||||||||||||||||||||||||||||||||
| source = experts_module if experts_module is not None else base | ||||||||||||||||||||||||||||||||
| if source is None: | ||||||||||||||||||||||||||||||||
| return None, None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| shape = _get_param_shape_from_module(source, parameter_name) | ||||||||||||||||||||||||||||||||
| if shape is not None and len(shape) >= 3: | ||||||||||||||||||||||||||||||||
| grouped_mm_format = bool(getattr(source, "_unsloth_grouped_mm_format", False)) | ||||||||||||||||||||||||||||||||
| if grouped_mm_format: | ||||||||||||||||||||||||||||||||
| return shape[-2], shape[-1] | ||||||||||||||||||||||||||||||||
| return shape[-1], shape[-2] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| hidden_dim = getattr(source, "hidden_dim", None) | ||||||||||||||||||||||||||||||||
| intermediate_dim = getattr(source, "intermediate_dim", None) | ||||||||||||||||||||||||||||||||
| if hidden_dim is None or intermediate_dim is None: | ||||||||||||||||||||||||||||||||
| return None, None | ||||||||||||||||||||||||||||||||
| if parameter_name == "gate_up_proj": | ||||||||||||||||||||||||||||||||
| return hidden_dim, 2 * intermediate_dim | ||||||||||||||||||||||||||||||||
| if parameter_name == "down_proj": | ||||||||||||||||||||||||||||||||
| return intermediate_dim, hidden_dim | ||||||||||||||||||||||||||||||||
| return None, None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def extract_moe_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| wrapper, | ||||||||||||||||||||||||||||||||
| weight_A: torch.Tensor, | ||||||||||||||||||||||||||||||||
| weight_B: torch.Tensor, | ||||||||||||||||||||||||||||||||
| scaling, | ||||||||||||||||||||||||||||||||
| num_experts: int, | ||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||
| experts_module=None, | ||||||||||||||||||||||||||||||||
| input_dim=None, | ||||||||||||||||||||||||||||||||
| output_dim=None, | ||||||||||||||||||||||||||||||||
| model_name: str = "MoE", | ||||||||||||||||||||||||||||||||
| enable_logging: bool = None, | ||||||||||||||||||||||||||||||||
| logger_obj=None, | ||||||||||||||||||||||||||||||||
| ) -> Tuple[torch.Tensor, torch.Tensor, float, int]: | ||||||||||||||||||||||||||||||||
| total_rank = weight_A.shape[0] | ||||||||||||||||||||||||||||||||
| rank_per_expert = total_rank // num_experts | ||||||||||||||||||||||||||||||||
| dim_A = weight_A.shape[1] | ||||||||||||||||||||||||||||||||
| dim_B = weight_B.shape[0] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if num_experts <= 1: | ||||||||||||||||||||||||||||||||
| return weight_A.T, weight_B.T, scaling, num_experts | ||||||||||||||||||||||||||||||||
|
Comment on lines
+423
to
+429
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current logic for For
Suggested change
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if input_dim is None or output_dim is None: | ||||||||||||||||||||||||||||||||
| inferred_input_dim, inferred_output_dim = _get_moe_lora_io_dims( | ||||||||||||||||||||||||||||||||
| wrapper, experts_module=experts_module, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| if input_dim is None: | ||||||||||||||||||||||||||||||||
| input_dim = inferred_input_dim | ||||||||||||||||||||||||||||||||
| if output_dim is None: | ||||||||||||||||||||||||||||||||
| output_dim = inferred_output_dim | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| canonical_match = ( | ||||||||||||||||||||||||||||||||
| input_dim is not None | ||||||||||||||||||||||||||||||||
| and output_dim is not None | ||||||||||||||||||||||||||||||||
| and dim_A == input_dim | ||||||||||||||||||||||||||||||||
| and dim_B == output_dim | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| reversed_match = ( | ||||||||||||||||||||||||||||||||
| input_dim is not None | ||||||||||||||||||||||||||||||||
| and output_dim is not None | ||||||||||||||||||||||||||||||||
| and dim_A == output_dim | ||||||||||||||||||||||||||||||||
| and dim_B == input_dim | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if canonical_match and reversed_match: | ||||||||||||||||||||||||||||||||
| if bool(getattr(wrapper, "_did_swap_in_out_features", False)): | ||||||||||||||||||||||||||||||||
| first_weight, second_weight = _reversed_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| first_weight, second_weight = _canonical_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight, scaling, num_experts | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if canonical_match: | ||||||||||||||||||||||||||||||||
| first_weight, second_weight = _canonical_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight, scaling, num_experts | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if reversed_match: | ||||||||||||||||||||||||||||||||
| first_weight, second_weight = _reversed_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight, scaling, num_experts | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if logger_obj is not None: | ||||||||||||||||||||||||||||||||
| if enable_logging is None: | ||||||||||||||||||||||||||||||||
| enable_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1" | ||||||||||||||||||||||||||||||||
| if enable_logging and (input_dim is not None or output_dim is not None): | ||||||||||||||||||||||||||||||||
| logger_obj.warning( | ||||||||||||||||||||||||||||||||
| f"Unsloth: {model_name} LoRA extractor could not match either layout " | ||||||||||||||||||||||||||||||||
| f"(weight_A={tuple(weight_A.shape)}, weight_B={tuple(weight_B.shape)}, " | ||||||||||||||||||||||||||||||||
| f"expected input_dim={input_dim}, output_dim={output_dim}, " | ||||||||||||||||||||||||||||||||
| f"num_experts={num_experts}). Falling back to canonical layout. " | ||||||||||||||||||||||||||||||||
| "If this is a new PEFT version, the LoRA delta may be wrong." | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| first_weight, second_weight = _canonical_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| weight_A, weight_B, num_experts, rank_per_expert, dim_A, dim_B, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return first_weight, second_weight, scaling, num_experts | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _extract_lora_from_wrapper( | ||||||||||||||||||||||||||||||||
| wrapper, adapter_name: str = "default", experts_module=None | ||||||||||||||||||||||||||||||||
| ) -> Optional[Tuple[torch.Tensor, torch.Tensor, float, int]]: | ||||||||||||||||||||||||||||||||
|
|
@@ -404,32 +561,15 @@ def _extract_lora_from_wrapper( | |||||||||||||||||||||||||||||||
| if extractor_fn is not None: | ||||||||||||||||||||||||||||||||
| return extractor_fn(wrapper, weight_A, weight_B, scaling, num_experts) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # DEFAULT BEHAVIOR (Standard Format / Non-MoE) | ||||||||||||||||||||||||||||||||
| if num_experts > 1: | ||||||||||||||||||||||||||||||||
| total_rank = weight_A.shape[0] | ||||||||||||||||||||||||||||||||
| rank_per_expert = total_rank // num_experts | ||||||||||||||||||||||||||||||||
| dim1 = weight_A.shape[1] | ||||||||||||||||||||||||||||||||
| dim2 = weight_B.shape[0] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # STANDARD FORMAT (Qwen3-MoE / GLM4): | ||||||||||||||||||||||||||||||||
| # Base weights are (E, out_dim, in_dim) for F.linear. | ||||||||||||||||||||||||||||||||
| # LoRA weights follow PEFT: weight_A is (E*R, in_dim), weight_B is (out_dim, E*R). | ||||||||||||||||||||||||||||||||
| # We need X @ (E, in_dim, R) @ (E, R, out_dim). | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # first_weight: (E, in_dim, R) - from lora_A | ||||||||||||||||||||||||||||||||
| # second_weight: (E, R, out_dim) - from lora_B | ||||||||||||||||||||||||||||||||
| first_weight = weight_A.view(num_experts, rank_per_expert, dim1) | ||||||||||||||||||||||||||||||||
| first_weight = first_weight.permute(0, 2, 1).contiguous() # (E, dim1, R) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # second_weight (B): (E, R, out_dim) | ||||||||||||||||||||||||||||||||
| second_weight = weight_B.view(dim2, num_experts, rank_per_expert) | ||||||||||||||||||||||||||||||||
| second_weight = second_weight.permute(1, 2, 0).contiguous() # (E, R, dim2) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| # Non-MoE case: return weights for X @ A.T @ B.T | ||||||||||||||||||||||||||||||||
| first_weight = weight_A.T # (dim1, R) | ||||||||||||||||||||||||||||||||
| second_weight = weight_B.T # (R, dim2) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return first_weight, second_weight, scaling, num_experts | ||||||||||||||||||||||||||||||||
| return extract_moe_lora_weights_for_grouped_mm( | ||||||||||||||||||||||||||||||||
| wrapper, | ||||||||||||||||||||||||||||||||
| weight_A, | ||||||||||||||||||||||||||||||||
| weight_B, | ||||||||||||||||||||||||||||||||
| scaling, | ||||||||||||||||||||||||||||||||
| num_experts, | ||||||||||||||||||||||||||||||||
| experts_module=experts_module, | ||||||||||||||||||||||||||||||||
| model_name="MoE", | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The shared extractor now infers MoE input/output dims from parameter shape and, when
_unsloth_grouped_mm_formatis not set, assumes a(E, out_dim, in_dim)layout by returningshape[-1], shape[-2]. GLM4 previously had explicit handling for its transposed(E, in_dim, out_dim)cases, but this commit removed that model-specific logic and routes through the generic path, so transposed GLM4 checkpoints can be misclassified as canonical. In that scenariofirst_weightis built on the wrong axis, which can cause grouped-mm dimension mismatches or silently wrong LoRA deltas during training.Useful? React with 👍 / 👎.