# Patch summary. This preamble sits before the first "diff --git" header and is
# not part of any hunk; the patch text below it is unchanged.
#
# unsloth_zoo/compiler.py (both hunks are inside create_standalone_class):
#   * Decorator-stripping loop: a line starting with "@" is now skipped when it
#     contains "use_experts_implementation", "use_kernel_forward_from_hub" or
#     "use_kernelized_func", or when it starts with "@auto_docstring".
#     Previously only "use_experts_implementation" was skipped.
#   * The info log no longer hard-codes the decorator name; it logs the text
#     before the first "(" with leading "@" characters removed.
#   * A re.sub call that removes "@use_kernelized_func", optional whitespace and
#     an optional "(...)" group is added between the existing @auto_docstring
#     and @check_model_inputs substitutions.
#
# unsloth_zoo/temporary_patches/gpt_oss.py:
#   * New patch_gpt_oss_compiler_exports(): returns early unless the
#     UNSLOTH_MODEL_NAME environment variable (with "-" replaced by "_")
#     contains "gpt_oss". It then imports
#     transformers.models.gpt_oss.modeling_gpt_oss (calling raise_error and
#     returning if the import raises) and assigns ParameterModule,
#     swiglu_torch_forward, dtype_from_config, transformers_version and Version
#     as attributes of that module. It is registered with
#     TEMPORARY_PATCHES.append.
#   * New helper _normalized_unsloth_model_name(): returns UNSLOTH_MODEL_NAME
#     (default "") with "-" replaced by "_".
#   * The "gpt_oss" / "_load_in_4bit_" substring checks in
#     _should_use_gpt_oss_bnb4bit, _is_gpt_oss_4bit_load,
#     patch_gpt_oss_moe_for_lora, patch_gpt_oss_linearized,
#     patch_GptOssAttention, patch_GptOssModel,
#     patch_gpt_oss_init_weights_modulelist_fix and patch_gpt_oss_for_grpo now
#     call that helper instead of reading os.environ directly.
#   * return_attention_mask (nested in wrap, in the patch_GptOssModel hunk):
#     the embeddings tensor is looked up as kwargs "input_embeds", then kwargs
#     "inputs_embeds", then the first positional argument whose exact type is
#     torch.Tensor and which is floating point. It no longer indexes
#     kwargs["input_embeds"] directly. The requires_grad branch is taken only
#     when a tensor was found. The positional mask search accepts int32, int64
#     and bool tensors (previously int32 only) and calls f(*args, **kwargs) when
#     no positional argument matches.
#
# NOTE(review): patch_gpt_oss_compiler_exports repeats the .replace("-", "_")
#   normalisation inline rather than calling _normalized_unsloth_model_name();
#   consider reusing the helper so there is a single definition.
# NOTE(review): the positional mask search returns the first int32/int64/bool
#   tensor it finds. Presumably the attention mask is always the first such
#   positional argument; confirm against the wrapped function's signature.
# NOTE(review): "use_kernel_forward_from_hub" is added to the loop but has no
#   matching re.sub in the second compiler.py hunk. It may be handled elsewhere
#   in the file; not visible here, so verify.
diff --git a/unsloth_zoo/compiler.py b/unsloth_zoo/compiler.py index e3d42fda2..ebf09124e 100644 --- a/unsloth_zoo/compiler.py +++ b/unsloth_zoo/compiler.py @@ -1151,8 +1151,14 @@ def create_standalone_class( for line in lines: stripped = line.strip() if stripped.startswith("@"): - if "use_experts_implementation" in stripped: - logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}') + if ( + "use_experts_implementation" in stripped + or "use_kernel_forward_from_hub" in stripped + or "use_kernelized_func" in stripped + or stripped.startswith("@auto_docstring") + ): + decorator_name = stripped.split("(")[0].lstrip("@") + logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}") continue # Strip it else: logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.") @@ -1269,6 +1275,7 @@ def create_standalone_class( # Remove @auto_docstring source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source) + source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source) source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source) # source = source.replace("@auto_docstring", "") diff --git a/unsloth_zoo/temporary_patches/gpt_oss.py b/unsloth_zoo/temporary_patches/gpt_oss.py index 1388d78cc..0b0851ab1 100644 --- a/unsloth_zoo/temporary_patches/gpt_oss.py +++ b/unsloth_zoo/temporary_patches/gpt_oss.py @@ -648,6 +648,26 @@ def _load_from_state_dict( ) +def patch_gpt_oss_compiler_exports(): + model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") + if "gpt_oss" not in model_name: + return + try: + import transformers.models.gpt_oss.modeling_gpt_oss + except Exception as e: + raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e) + return + + # Export helpers so compiler generated GPT-OSS modules can resolve symbols. 
+ m = transformers.models.gpt_oss.modeling_gpt_oss + m.ParameterModule = ParameterModule + m.swiglu_torch_forward = swiglu_torch_forward + m.dtype_from_config = dtype_from_config + m.transformers_version = transformers_version + m.Version = Version +TEMPORARY_PATCHES.append(patch_gpt_oss_compiler_exports) + + class GptOssExperts(nn.Module): """ GPT OSS MoE Experts layer with 3D stacked parameters. @@ -1316,15 +1336,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool: Default: True when load_in_4bit is active. Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return False - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return False return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" def _is_gpt_oss_4bit_load() -> bool: - return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") + return "_load_in_4bit_" in _normalized_unsloth_model_name() + + +def _normalized_unsloth_model_name() -> str: + return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") def _is_transformers_v5() -> bool: @@ -1340,7 +1364,7 @@ def patch_gpt_oss_moe_for_lora(): IMPORTANT: We only patch the forward method, NOT replace the entire class. This preserves the original class structure so weights load correctly. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): # 4-bit loads should keep quantized weights and use default PEFT LoRA. @@ -1774,8 +1798,8 @@ def patch_gpt_oss_linearized(): Patch GPT OSS for 4bit loading with grouped_mm support. Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. 
 """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return - if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return + if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return if _should_use_gpt_oss_bnb4bit(): return try: import transformers.models.gpt_oss.modeling_gpt_oss @@ -1813,7 +1837,7 @@ def experts_forward( def patch_GptOssAttention(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: from ..flex_attention import ( flex_attention_with_sink, @@ -2054,7 +2078,7 @@ def forward( def patch_GptOssModel(): if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: import transformers.models.gpt_oss.modeling_gpt_oss transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel @@ -2075,12 +2099,25 @@ def patch_GptOssModel(): import transformers.generation.utils def wrap(f): def return_attention_mask(*args, **kwargs): - if kwargs["input_embeds"].requires_grad: + input_embeds = kwargs.get("input_embeds", None) + if input_embeds is None: + input_embeds = kwargs.get("inputs_embeds", None) + if input_embeds is None: + for arg in args: + if type(arg) is torch.Tensor and arg.is_floating_point(): + input_embeds = arg + break + + if input_embeds is not None and input_embeds.requires_grad: if "attention_mask" in kwargs: return kwargs["attention_mask"] for arg in args: - if type(arg) is torch.Tensor and arg.dtype == torch.int32: + if ( + type(arg) is torch.Tensor and + arg.dtype in (torch.int32, torch.int64, torch.bool) + ): return arg + return f(*args, **kwargs) else: # Eager return f(*args, **kwargs) @@ -2739,7 +2776,7 @@ def patch_gpt_oss_config(): def 
patch_gpt_oss_init_weights_modulelist_fix(): - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: import transformers.models.gpt_oss.modeling_gpt_oss @@ -2784,7 +2821,7 @@ def patch_gpt_oss_for_grpo(): When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. This fixes the matrix multiplication dimension mismatch issue in GRPO training. """ - if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): + if "gpt_oss" not in _normalized_unsloth_model_name(): return try: