Skip to content
26 changes: 17 additions & 9 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,14 @@ def create_standalone_class(
for line in lines:
stripped = line.strip()
if stripped.startswith("@"):
if "use_experts_implementation" in stripped:
logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}')
if (
"use_experts_implementation" in stripped
or "use_kernel_forward_from_hub" in stripped
or "use_kernelized_func" in stripped
or stripped.startswith("@auto_docstring")
):
decorator_name = stripped.split("(")[0].lstrip("@")
logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}")
continue # Strip it
else:
logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.")
Expand All @@ -1165,13 +1171,14 @@ def create_standalone_class(
# Check if forward was replaced by a temporary patch (renamed function)
# In this case, keep the patched source as-is and replace the class forward body.
patched_forward_info = None
func_match = re.search(r"def\s+(\w+)\s*\(", forward_source)
if func_match and func_match.group(1) != "forward":
# Find original forward in class to replace it
orig_fwd = re.search(r"(\n\s+def\s+forward\s*\([^)]*\)[^:]*:.*?)(?=\n\s+def\s|\n\s+@|\Z)", full_class, re.DOTALL)
if orig_fwd:
patched_forward_info = (func_match.group(1), orig_fwd.group(1))
disable = None # Keep patched source as-is for renamed forward replacements
if "@torch.compiler.disable" in forward_source:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Detect renamed forward patches without disable decorator

Limiting renamed-forward detection to sources containing @torch.compiler.disable skips valid patched forwards that are renamed but undecorated (for example patch_function(DeepseekV3MoE, "forward", patched_moe_forward) in temporary_patches/deepseek_v3_moe.py). When this branch is skipped, create_standalone_class no longer swaps the class’s original forward with the patched implementation, so compiled modules silently fall back to stale/original forward logic and lose the runtime patch behavior.

Useful? React with 👍 / 👎.

func_match = re.search(r"def\s+(\w+)\s*\(", forward_source)
if func_match and func_match.group(1) != "forward":
# Find original forward in class to replace it
orig_fwd = re.search(r"(\n\s+def\s+forward\s*\([^)]*\)[^:]*:.*?)(?=\n\s+def\s|\n\s+@|\Z)", full_class, re.DOTALL)
if orig_fwd:
patched_forward_info = (func_match.group(1), orig_fwd.group(1))
disable = None # Keep patched source as-is for renamed forward replacements

# Replace function name with module-specific name
if patched_forward_info:
Expand Down Expand Up @@ -1269,6 +1276,7 @@ def create_standalone_class(

# Remove @auto_docstring
source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source)
source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source)
source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
# source = source.replace("@auto_docstring", "")

Expand Down
61 changes: 49 additions & 12 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,26 @@ def _load_from_state_dict(
)


def patch_gpt_oss_compiler_exports():
model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")
if "gpt_oss" not in model_name:
return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
except Exception as e:
raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e)
return

# Export helpers so compiler generated GPT-OSS modules can resolve symbols.
m = transformers.models.gpt_oss.modeling_gpt_oss
m.ParameterModule = ParameterModule
m.swiglu_torch_forward = swiglu_torch_forward
m.dtype_from_config = dtype_from_config
m.transformers_version = transformers_version
m.Version = Version
TEMPORARY_PATCHES.append(patch_gpt_oss_compiler_exports)


class GptOssExperts(nn.Module):
"""
GPT OSS MoE Experts layer with 3D stacked parameters.
Expand Down Expand Up @@ -1316,15 +1336,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool:
Default: True when load_in_4bit is active.
Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return False
if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "_load_in_4bit_" not in _normalized_unsloth_model_name():
return False
return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1"


def _is_gpt_oss_4bit_load() -> bool:
return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "")
return "_load_in_4bit_" in _normalized_unsloth_model_name()


def _normalized_unsloth_model_name() -> str:
return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")


def _is_transformers_v5() -> bool:
Expand All @@ -1340,7 +1364,7 @@ def patch_gpt_oss_moe_for_lora():
IMPORTANT: We only patch the forward method, NOT replace the entire class.
This preserves the original class structure so weights load correctly.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return
if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit():
# 4-bit loads should keep quantized weights and use default PEFT LoRA.
Expand Down Expand Up @@ -1774,8 +1798,8 @@ def patch_gpt_oss_linearized():
Patch GPT OSS for 4bit loading with grouped_mm support.
Only patches the GptOssExperts forward method - keeps original classes for proper weight loading.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return
if _should_use_gpt_oss_bnb4bit(): return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
Expand Down Expand Up @@ -1813,7 +1837,7 @@ def experts_forward(

def patch_GptOssAttention():
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
try:
from ..flex_attention import (
flex_attention_with_sink,
Expand Down Expand Up @@ -2054,7 +2078,7 @@ def forward(

def patch_GptOssModel():
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel
Expand All @@ -2075,12 +2099,25 @@ def patch_GptOssModel():
import transformers.generation.utils
def wrap(f):
def return_attention_mask(*args, **kwargs):
if kwargs["input_embeds"].requires_grad:
input_embeds = kwargs.get("input_embeds", None)
if input_embeds is None:
input_embeds = kwargs.get("inputs_embeds", None)
if input_embeds is None:
for arg in args:
if type(arg) is torch.Tensor and arg.is_floating_point():
input_embeds = arg
break

if input_embeds is not None and input_embeds.requires_grad:
if "attention_mask" in kwargs:
return kwargs["attention_mask"]
for arg in args:
if type(arg) is torch.Tensor and arg.dtype == torch.int32:
if (
type(arg) is torch.Tensor and
arg.dtype in (torch.int32, torch.int64, torch.bool)
):
return arg
return f(*args, **kwargs)
else:
# Eager
return f(*args, **kwargs)
Expand Down Expand Up @@ -2739,7 +2776,7 @@ def patch_gpt_oss_config():


def patch_gpt_oss_init_weights_modulelist_fix():
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
Expand Down Expand Up @@ -2784,7 +2821,7 @@ def patch_gpt_oss_for_grpo():
When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits.
This fixes the matrix multiplication dimension mismatch issue in GRPO training.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return

try:
Expand Down
Loading