diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e3d42fda2..1fa400251 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1151,8 +1151,14 @@ def create_standalone_class( for line in lines: stripped = line.strip() if stripped.startswith("@"): - if "use_experts_implementation" in stripped: - logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}') + if ( + "use_experts_implementation" in stripped + or "use_kernel_forward_from_hub" in stripped + or "use_kernelized_func" in stripped + or stripped.startswith("@auto_docstring") + ): + decorator_name = stripped.split("(")[0].lstrip("@") + logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}") continue # Strip it else: logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.") @@ -1165,13 +1171,14 @@ def create_standalone_class( # Check if forward was replaced by a temporary patch (renamed function) # In this case, keep the patched source as-is and replace the class forward body. patched_forward_info = None - func_match = re.search(r"def\s+(\w+)\s*\(", forward_source) - if func_match and func_match.group(1) != "forward": - # Find original forward in class to replace it - orig_fwd = re.search(r"(\n\s+def\s+forward\s*\([^)]*\)[^:]*:.*?)(?=\n\s+def\s|\n\s+@|\Z)", full_class, re.DOTALL) - if orig_fwd: - patched_forward_info = (func_match.group(1), orig_fwd.group(1)) - disable = None # Keep patched source as-is for renamed forward replacements + if "@torch.compiler.disable" in forward_source: + func_match = re.search(r"def\s+(\w+)\s*\(", forward_source) + if func_match and func_match.group(1) != "forward": + # Find original forward in class to replace it + orig_fwd = re.search(r"(\n\s+def\s+forward\s*\([^)]*\)[^:]*:.*?)(?=\n\s+def\s|\n\s+@|\Z)", full_class, re.DOTALL) + if orig_fwd: + patched_forward_info = (func_match.group(1), orig_fwd.group(1)) + disable = None # Keep patched source as-is for renamed forward replacements # Replace function name with module-specific name if patched_forward_info: @@ -1269,6 +1276,7 @@ def create_standalone_class( # Remove @auto_docstring source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source) + source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source) source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source) # source = source.replace("@auto_docstring", "") diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 1388d78cc..0b0851ab1 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -648,6 +648,26 @@ def _load_from_state_dict( ) +def patch_gpt_oss_compiler_exports(): + model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") + if "gpt_oss" not in model_name: + return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + return + + # Export helpers so compiler generated GPT-OSS modules can resolve symbols. + m = transformers.models.gpt_oss.modeling_gpt_oss + m.ParameterModule = ParameterModule + m.swiglu_torch_forward = swiglu_torch_forward + m.dtype_from_config = dtype_from_config + m.transformers_version = transformers_version + m.Version = Version +TEMPORARY_PATCHES.append(patch_gpt_oss_compiler_exports) + + class GptOssExperts(nn.Module): """ GPT OSS MoE Experts layer with 3D stacked parameters. @@ -1316,15 +1336,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool: Default: True when load_in_4bit is active. Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return False - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return False return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" def _is_gpt_oss_4bit_load() -> bool: - return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") + return "_load_in_4bit_" in _normalized_unsloth_model_name() + + +def _normalized_unsloth_model_name() -> str: + return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") def _is_transformers_v5() -> bool: @@ -1340,7 +1364,7 @@ def patch_gpt_oss_moe_for_lora(): IMPORTANT: We only patch the forward method, NOT replace the entire class. This preserves the original class structure so weights load correctly. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): # 4-bit loads should keep quantized weights and use default PEFT LoRA. @@ -1774,8 +1798,8 @@ def patch_gpt_oss_linearized(): Patch GPT OSS for 4bit loading with grouped_mm support. Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return + if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return if _should_use_gpt_oss_bnb4bit(): return try: import transformers.models.gpt_oss.modeling_gpt_oss @@ -1813,7 +1837,7 @@ def experts_forward( def patch_GptOssAttention(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: from ..flex_attention import ( flex_attention_with_sink, @@ -2054,7 +2078,7 @@ def forward( def patch_GptOssModel(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: import transformers.models.gpt_oss.modeling_gpt_oss transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel @@ -2075,12 +2099,25 @@ def patch_GptOssModel(): import transformers.generation.utils def wrap(f): def return_attention_mask(*args, **kwargs): - if kwargs["input_embeds"].requires_grad: + input_embeds = kwargs.get("input_embeds", None) + if input_embeds is None: + input_embeds = kwargs.get("inputs_embeds", None) + if input_embeds is None: + for arg in args: + if type(arg) is torch.Tensor and arg.is_floating_point(): + input_embeds = arg + break + + if input_embeds is not None and input_embeds.requires_grad: if "attention_mask" in kwargs: return kwargs["attention_mask"] for arg in args: - if type(arg) is torch.Tensor and arg.dtype == torch.int32: + if ( + type(arg) is torch.Tensor and + arg.dtype in (torch.int32, torch.int64, torch.bool) + ): return arg + return f(*args, **kwargs) else: # Eager return f(*args, **kwargs) @@ -2739,7 +2776,7 @@ def patch_gpt_oss_config(): def patch_gpt_oss_init_weights_modulelist_fix(): - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: import transformers.models.gpt_oss.modeling_gpt_oss @@ -2784,7 +2821,7 @@ def patch_gpt_oss_for_grpo(): When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. This fixes the matrix multiplication dimension mismatch issue in GRPO training. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py index d42b8445b..d1f3c3735 100644 --- a/unsloth_zoo/temporary_patches/misc.py +++ b/unsloth_zoo/temporary_patches/misc.py @@ -1136,334 +1136,6 @@ def patched_forward(self, image_features): TEMPORARY_PATCHES.append(patch_Lfm2VlMultiModalProjector) -def patch_peft_moe_4bit_paramwrapper_and_injection(): - """Patch PEFT for MoE 4bit parameter injection and ParamWrapper behavior.""" - try: - import torch.nn as nn - from contextlib import contextmanager - import peft.tuners.tuners_utils as tuners_utils - import peft.tuners.lora.layer as lora_layer - from peft.utils.integrations import get_bnb_param_type - except Exception: - return - - # 1) Prevent endless adapter injection loops by snapshotting named_modules traversal. - try: - base_tuner_cls = tuners_utils.BaseTuner - orig_inject = base_tuner_cls._inject_parameters - if not getattr(orig_inject, "_unsloth_moe_snapshot_named_modules", False): - source = inspect.getsource(orig_inject) - old = "for module_name, module in model.named_modules():" - new = "for module_name, module in list(model.named_modules()):" - if old in source and new not in source: - source = dedent(source).replace(old, new) - scope = dict(tuners_utils.__dict__) - exec(source, scope) - patched = scope["_inject_parameters"] - patched._unsloth_moe_snapshot_named_modules = True - patch_function(base_tuner_cls, "_inject_parameters", patched, match_level="relaxed") - except Exception: - pass - - # 2) Ensure LoRA proxy dequantizes 4bit params before adding LoRA delta. - try: - proxy_cls = lora_layer._LoraParameterProxy - orig_proxy_forward = proxy_cls.forward - if not getattr(orig_proxy_forward, "_unsloth_moe_4bit_proxy_forward", False): - def _patched_proxy_forward(self, W): - if get_bnb_param_type(W) == "4bit": - import bitsandbytes as bnb - W = bnb.functional.dequantize_4bit(W.data, W.quant_state) - return W + self.delta_weight - - _patched_proxy_forward._unsloth_moe_4bit_proxy_forward = True - patch_function(proxy_cls, "forward", _patched_proxy_forward, match_level="relaxed") - except Exception: - pass - - # 3) ParamWrapper fixes for 4bit params: shape inference, dtype placement, delta dtype, unsafe parametrization. - try: - param_wrapper_cls = lora_layer.ParamWrapper - - orig_get_in_out = param_wrapper_cls._get_in_out_features - if not getattr(orig_get_in_out, "_unsloth_moe_4bit_in_out", False): - def _patched_get_in_out_features(self, module): - param = self.get_param() - if get_bnb_param_type(param) == "4bit": - logical_shape = getattr(getattr(param, "quant_state", None), "shape", None) - if logical_shape is None: - raise ValueError( - f"lora.{self.__class__.__name__} got a 4bit parameter without quant_state.shape, cannot infer shape." - ) - if len(logical_shape) == 3: - num_experts, in_features, out_features = logical_shape - elif len(logical_shape) == 2: - num_experts, in_features, out_features = 1, logical_shape[1], logical_shape[0] - else: - raise ValueError( - f"lora.{self.__class__.__name__} was initialized with {len(logical_shape)} dimensional Parameter, but only 2d and 3d are supported." - ) - self.num_experts = num_experts - return in_features, out_features - return orig_get_in_out(self, module) - - _patched_get_in_out_features._unsloth_moe_4bit_in_out = True - patch_function(param_wrapper_cls, "_get_in_out_features", _patched_get_in_out_features, match_level="relaxed") - - orig_move_adapter = param_wrapper_cls._move_adapter_to_device_of_base_layer - if not getattr(orig_move_adapter, "_unsloth_moe_4bit_adapter_dtype", False): - def _patched_move_adapter_to_device(self, adapter_name: str, device: Optional[torch.device] = None): - if device is None: - param = self.get_param() - device = param.device - else: - param = self.get_param() - - target_dtype = None - if param.is_floating_point() or param.is_complex(): - if get_bnb_param_type(param) == "4bit": - target_dtype = getattr(getattr(param, "quant_state", None), "dtype", None) - else: - target_dtype = param.dtype - - meta = torch.device("meta") - for adapter_layer_name in self.adapter_layer_names + self.other_param_names: - adapter_layer = getattr(self, adapter_layer_name, None) - if not isinstance(adapter_layer, (nn.ModuleDict, nn.ParameterDict, lora_layer.BufferDict)): - continue - if adapter_name not in adapter_layer: - continue - if any(p.device == meta for p in adapter_layer.parameters()): - continue - if target_dtype is not None: - adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=target_dtype) - else: - adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device) - - _patched_move_adapter_to_device._unsloth_moe_4bit_adapter_dtype = True - patch_function( - param_wrapper_cls, - "_move_adapter_to_device_of_base_layer", - _patched_move_adapter_to_device, - match_level="relaxed", - ) - - orig_get_delta = param_wrapper_cls.get_delta_weight - if not getattr(orig_get_delta, "_unsloth_moe_4bit_delta_dtype", False): - def _patched_get_delta_weight(self, adapter_name, *args, **kwargs): - delta_weight = orig_get_delta(self, adapter_name, *args, **kwargs) - param = self.get_param() - if get_bnb_param_type(param) == "4bit": - target_dtype = getattr(getattr(param, "quant_state", None), "dtype", None) - if target_dtype is not None: - delta_weight = delta_weight.to(param.device, target_dtype) - return delta_weight - - _patched_get_delta_weight._unsloth_moe_4bit_delta_dtype = True - patch_function(param_wrapper_cls, "get_delta_weight", _patched_get_delta_weight, match_level="relaxed") - - if not getattr(param_wrapper_cls._activate_lora, "_unsloth_moe_4bit_unsafe_parametrize", False): - @contextmanager - def _patched_activate_lora(self, active_adapters: list[str]): - if not active_adapters or not any(adapter in self.lora_A for adapter in active_adapters): - yield - return - - delta_weight = None - for active_adapter in active_adapters: - if active_adapter not in self.lora_A: - continue - if delta_weight is None: - delta_weight = self.get_delta_weight(active_adapter) - else: - delta_weight = delta_weight + self.get_delta_weight(active_adapter) - - base_layer = self.get_base_layer() - requires_grad_before = self.get_param().requires_grad - unsafe = get_bnb_param_type(self.get_param()) == "4bit" - nn.utils.parametrize.register_parametrization( - base_layer, self.parameter_name, lora_layer._LoraParameterProxy(delta_weight), unsafe=unsafe - ) - base_layer.parametrizations[self.parameter_name].original.requires_grad_(requires_grad_before) - try: - with nn.utils.parametrize.cached(): - yield - finally: - self._remove_parametrizations() - - _patched_activate_lora._unsloth_moe_4bit_unsafe_parametrize = True - patch_function(param_wrapper_cls, "_activate_lora", _patched_activate_lora, match_level="relaxed") - except Exception: - pass -pass -TEMPORARY_PATCHES.append(patch_peft_moe_4bit_paramwrapper_and_injection) - - -def patch_transformers_bnb4bit_moe_param_quantization(): - """Patch transformers bitsandbytes quantization to include MoE gate_up_proj/down_proj parameters.""" - try: - import bitsandbytes as bnb - from transformers.pytorch_utils import Conv1D - from transformers.quantizers.quantizers_utils import get_module_from_name - import transformers.quantizers.quantizer_bnb_4bit as quantizer_bnb_4bit - import transformers.integrations.bitsandbytes as bnb_integration - except Exception: - return - - # 1) Mark MoE expert parameters as quantizable in 4-bit path. - try: - quantizer_cls = quantizer_bnb_4bit.Bnb4BitHfQuantizer - orig_param_needs_quantization = quantizer_cls.param_needs_quantization - if not getattr(orig_param_needs_quantization, "_unsloth_moe_param_needs_quantization", False): - def _patched_param_needs_quantization(self, model, param_name: str, **kwargs): - if param_name.endswith(".gate_up_proj") or param_name.endswith(".down_proj"): - return True - return orig_param_needs_quantization(self, model, param_name, **kwargs) - - _patched_param_needs_quantization._unsloth_moe_param_needs_quantization = True - patch_function( - quantizer_cls, - "param_needs_quantization", - _patched_param_needs_quantization, - match_level="relaxed", - ) - except Exception: - pass - - # 2) Safe convert for non-Linear parameters (MoE expert tensors are nn.Parameter, not module.weight). - try: - quantize_cls = bnb_integration.Bnb4bitQuantize - orig_convert = quantize_cls.convert - if not getattr(orig_convert, "_unsloth_moe_4bit_convert", False): - def _patched_convert( - self, - input_dict: dict[str, list[torch.Tensor]], - full_layer_name: str | None = None, - model: torch.nn.Module | None = None, - **kwargs, - ) -> dict[str, torch.Tensor]: - value = list(input_dict.values())[0] - if isinstance(value, list): - if ( - (full_layer_name.endswith(".gate_up_proj") or full_layer_name.endswith(".down_proj")) - and len(value) > 1 - and hasattr(value[0], "dim") - and value[0].dim() == 2 - ): - value = torch.stack(value, dim=0) - else: - value = value[0] - - module, _ = get_module_from_name(model, full_layer_name) - source_cls = getattr(module, "source_cls", None) - if source_cls is not None: - try: - if issubclass(source_cls, Conv1D): - value = value.T - except TypeError: - pass - - old_value = model.get_parameter_or_buffer(full_layer_name) - if ( - (full_layer_name.endswith(".gate_up_proj") or full_layer_name.endswith(".down_proj")) - and hasattr(old_value, "shape") - and hasattr(value, "numel") - and len(old_value.shape) == 3 - and value.numel() == old_value.numel() - and tuple(value.shape) != tuple(old_value.shape) - ): - value = value.view(old_value.shape) - new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device) - module._is_hf_initialized = True - return {full_layer_name: new_value} - - _patched_convert._unsloth_moe_4bit_convert = True - patch_function(quantize_cls, "convert", _patched_convert, match_level="relaxed") - except Exception: - pass -pass -TEMPORARY_PATCHES.append(patch_transformers_bnb4bit_moe_param_quantization) - - -def patch_transformers_moe_bnb4bit_dequantization(): - """Patch transformers.integrations.moe to dequantize Params4bit expert tensors before indexing/grouped-mm.""" - try: - import transformers.integrations.moe as moe - except Exception: - return - - def _patch_batched_mm(): - try: - f = moe.batched_mm_experts_forward - source = inspect.getsource(f) - if "dequantize_4bit(" in source: - return - old = ( - "selected_gate_up = self.gate_up_proj[expert_ids_clamped]\n" - " selected_down = self.down_proj[expert_ids_clamped]" - ) - new = ( - "gate_up_proj_is_quantized = is_bitsandbytes_available() and self.gate_up_proj.__class__.__name__ == \"Params4bit\"\n" - " down_proj_is_quantized = is_bitsandbytes_available() and self.down_proj.__class__.__name__ == \"Params4bit\"\n" - "\n" - " if gate_up_proj_is_quantized:\n" - " selected_gate_up = bnb.functional.dequantize_4bit(self.gate_up_proj.data, self.gate_up_proj.quant_state)[expert_ids_clamped]\n" - " else:\n" - " selected_gate_up = self.gate_up_proj[expert_ids_clamped]\n" - "\n" - " if down_proj_is_quantized:\n" - " selected_down = bnb.functional.dequantize_4bit(self.down_proj.data, self.down_proj.quant_state)[expert_ids_clamped]\n" - " else:\n" - " selected_down = self.down_proj[expert_ids_clamped]" - ) - if old not in source: - return - source = dedent(source).replace(old, new) - scope = dict(moe.__dict__) - exec(source, scope) - patch_function(moe, "batched_mm_experts_forward", scope["batched_mm_experts_forward"], match_level="relaxed") - except Exception: - pass - - def _patch_grouped_mm(): - try: - f = moe.grouped_mm_experts_forward - source = inspect.getsource(f) - if "dequantize_4bit(" in source: - return - old = ( - "selected_gate_up = self.gate_up_proj\n" - " selected_down = self.down_proj" - ) - new = ( - "gate_up_proj_is_quantized = is_bitsandbytes_available() and self.gate_up_proj.__class__.__name__ == \"Params4bit\"\n" - " down_proj_is_quantized = is_bitsandbytes_available() and self.down_proj.__class__.__name__ == \"Params4bit\"\n" - "\n" - " if gate_up_proj_is_quantized:\n" - " selected_gate_up = bnb.functional.dequantize_4bit(self.gate_up_proj.data, self.gate_up_proj.quant_state)\n" - " else:\n" - " selected_gate_up = self.gate_up_proj\n" - "\n" - " if down_proj_is_quantized:\n" - " selected_down = bnb.functional.dequantize_4bit(self.down_proj.data, self.down_proj.quant_state)\n" - " else:\n" - " selected_down = self.down_proj" - ) - if old not in source: - return - source = dedent(source).replace(old, new) - scope = dict(moe.__dict__) - exec(source, scope) - patch_function(moe, "grouped_mm_experts_forward", scope["grouped_mm_experts_forward"], match_level="relaxed") - except Exception: - pass - - _patch_batched_mm() - _patch_grouped_mm() -pass -TEMPORARY_PATCHES.append(patch_transformers_moe_bnb4bit_dequantization) - - def patch_peft_dispatch_bnb_4bit(): """Fix PEFT dispatch_bnb_4bit accessing compress_statistics on non-Params4bit weights. diff --git a/unsloth_zoo/temporary_patches/moe_utils.py b/unsloth_zoo/temporary_patches/moe_utils.py index a19a575f1..1954522a9 100644 --- a/unsloth_zoo/temporary_patches/moe_utils.py +++ b/unsloth_zoo/temporary_patches/moe_utils.py @@ -470,21 +470,6 @@ def _get_base_weight(param): if hasattr(param, "get_param"): return param.get_param() - # Auto-dequantize BitsAndBytes 4-bit packed MoE parameters for grouped_mm/LoRA forward - # (packed tensor shape is not the logical expert tensor shape). - quant_state = getattr(param, "quant_state", None) - if quant_state is not None: - param_cls_name = type(param).__name__ - if param_cls_name == "Params4bit": - try: - from bitsandbytes.functional import dequantize_4bit - dequantized = dequantize_4bit(param.data, quant_state=quant_state) - if hasattr(quant_state, "dtype") and quant_state.dtype is not None: - dequantized = dequantized.to(quant_state.dtype) - return dequantized.contiguous() - except Exception: - pass - # Handle Modules (Linear, etc.) if hasattr(param, "weight"): return param.weight