Skip to content
Open
11 changes: 9 additions & 2 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,14 @@ def create_standalone_class(
for line in lines:
stripped = line.strip()
if stripped.startswith("@"):
if "use_experts_implementation" in stripped:
logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}')
if (
"use_experts_implementation" in stripped
or "use_kernel_forward_from_hub" in stripped

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually nice. Never noticed this. Seems like a very recent change.
But this function is to strip the decorator. Why not have the decorator and use the kernel?

or "use_kernelized_func" in stripped
or stripped.startswith("@auto_docstring")
):
decorator_name = stripped.split("(")[0].lstrip("@")
logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}")
continue # Strip it
else:
logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.")
Expand Down Expand Up @@ -1269,6 +1275,7 @@ def create_standalone_class(

# Remove @auto_docstring
source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source)
source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source)
source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
# source = source.replace("@auto_docstring", "")

Expand Down
141 changes: 125 additions & 16 deletions unsloth_zoo/temporary_patches/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,13 +1094,105 @@ def patch_gpt_oss_bnb4bit_auto():
TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto)


_LOW_MEMORY_ACCELERATOR_BYTES = int(24 * 1024**3)


def _get_active_accelerator_index():
try:
if DEVICE_TYPE == "xpu":
if hasattr(torch, "xpu") and hasattr(torch.xpu, "current_device"):
return int(torch.xpu.current_device())
return 0
if hasattr(torch, "cuda") and hasattr(torch.cuda, "current_device"):
return int(torch.cuda.current_device())
except Exception:
pass
return 0


def _get_accelerator_total_memory_bytes():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: We should put these in utils.py (away from temporary_patches) as these have potential to be reused.
Also there was a PR about DeviceContext iirc which might better handle this
Here it is unslothai/unsloth#3875

try:
device_index = _get_active_accelerator_index()
if DEVICE_TYPE == "xpu":
return int(torch.xpu.memory.mem_get_info(device_index)[-1])
return int(torch.cuda.memory.mem_get_info(device_index)[-1])
except Exception:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to handle, such as RuntimeError, ImportError, or AttributeError in this context. This makes the code's intent clearer and more robust against unrelated issues.

Suggested change
except Exception:
except (RuntimeError, ImportError, AttributeError):

return None


def _get_effective_accelerator_memory_bytes():
total_memory = _get_accelerator_total_memory_bytes()
if total_memory is None:
return None
if DEVICE_TYPE != "xpu" and hasattr(torch.cuda, "get_per_process_memory_fraction"):
try:
device_index = _get_active_accelerator_index()
fraction = float(torch.cuda.get_per_process_memory_fraction(device_index))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Are we restricting this somewhere else or is this to respect some hypothetical user setting such a limit?

if 0.0 < fraction < 1.0:
return int(total_memory * fraction)
except Exception:
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the previous comment, catching a broad Exception is discouraged. For torch.cuda.get_per_process_memory_fraction, it's better to specifically handle expected exceptions like RuntimeError or NotImplementedError to avoid masking other potential bugs.

Suggested change
except Exception:
except (RuntimeError, NotImplementedError):

pass
return total_memory


def _should_skip_transformers_allocator_warmup() -> bool:
"""
Skip transformers allocator warmup on low-memory accelerators.

`caching_allocator_warmup` can allocate large single chunks before weights
are loaded, which can OOM constrained GPUs.
"""
mode = os.environ.get("UNSLOTH_ALLOCATOR_WARMUP", "").strip().lower()
if mode in ("off", "disable", "0", "false"):
return True
if mode in ("on", "enable", "1", "true"):
return False

total_memory = _get_effective_accelerator_memory_bytes()
if total_memory is None:
return False
return total_memory <= _LOW_MEMORY_ACCELERATOR_BYTES


def patch_transformers_caching_allocator_warmup():
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
try:
import transformers.modeling_utils
except Exception as e:
return raise_error("transformers.modeling_utils", e)

warmup_fn = transformers.modeling_utils.caching_allocator_warmup
if hasattr(warmup_fn, "__unsloth_allocator_warmup_guarded__"):
return
# Backward compatibility with previous guard attribute.
if hasattr(warmup_fn, "__unsloth_gpt_oss_guarded__"):
return

def guarded_caching_allocator_warmup(model, expanded_device_map, hf_quantizer):
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
Comment thread
github-code-quality[bot] marked this conversation as resolved.
Fixed
if _should_skip_transformers_allocator_warmup():
if UNSLOTH_ENABLE_LOGGING:
logger.warning_once(
"Unsloth: Skipping transformers caching_allocator_warmup "
"on low-memory accelerators (<24GB effective memory). "
"Set UNSLOTH_ALLOCATOR_WARMUP=on to keep warmup."
)
return
return warmup_fn(model, expanded_device_map, hf_quantizer)

guarded_caching_allocator_warmup.__unsloth_allocator_warmup_guarded__ = True
# Keep legacy marker so older checks still detect this as guarded.
guarded_caching_allocator_warmup.__unsloth_gpt_oss_guarded__ = True
transformers.modeling_utils.caching_allocator_warmup = guarded_caching_allocator_warmup


TEMPORARY_PATCHES.append(patch_transformers_caching_allocator_warmup)


# Combo kernels uses too much VRAM for low memory GPUs
from ..device_type import DEVICE_TYPE

if DEVICE_TYPE == "xpu":
device_memory = torch.xpu.memory.mem_get_info(0)[-1]
else:
device_memory = torch.cuda.memory.mem_get_info(0)[-1]
device_memory = _get_accelerator_total_memory_bytes()
if device_memory is None:
device_memory = 0
use_combo_kernels = False if device_memory/1024/1024/1024 <= 40 else True
fused_torch_compile_options = get_torch_compile_options(
epilogue_fusion = True,
Expand Down Expand Up @@ -1316,15 +1408,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool:
Default: True when load_in_4bit is active.
Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this being a repeat of #519 (some of it). We should try to consolidate to avoid merge conflicts

return False
if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "_load_in_4bit_" not in _normalized_unsloth_model_name():
return False
return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1"


def _is_gpt_oss_4bit_load() -> bool:
return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "")
return "_load_in_4bit_" in _normalized_unsloth_model_name()


def _normalized_unsloth_model_name() -> str:
return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")


def _is_transformers_v5() -> bool:
Expand All @@ -1340,7 +1436,7 @@ def patch_gpt_oss_moe_for_lora():
IMPORTANT: We only patch the forward method, NOT replace the entire class.
This preserves the original class structure so weights load correctly.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return
if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit():
# 4-bit loads should keep quantized weights and use default PEFT LoRA.
Expand Down Expand Up @@ -1774,8 +1870,8 @@ def patch_gpt_oss_linearized():
Patch GPT OSS for 4bit loading with grouped_mm support.
Only patches the GptOssExperts forward method - keeps original classes for proper weight loading.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return
if _should_use_gpt_oss_bnb4bit(): return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
Expand Down Expand Up @@ -1813,7 +1909,7 @@ def experts_forward(

def patch_GptOssAttention():
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
try:
from ..flex_attention import (
flex_attention_with_sink,
Expand Down Expand Up @@ -2054,7 +2150,7 @@ def forward(

def patch_GptOssModel():
if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
if "gpt_oss" not in _normalized_unsloth_model_name(): return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel
Expand All @@ -2075,12 +2171,25 @@ def patch_GptOssModel():
import transformers.generation.utils
def wrap(f):
def return_attention_mask(*args, **kwargs):
if kwargs["input_embeds"].requires_grad:
input_embeds = kwargs.get("input_embeds", None)
if input_embeds is None:
input_embeds = kwargs.get("inputs_embeds", None)
if input_embeds is None:
for arg in args:
if type(arg) is torch.Tensor and arg.is_floating_point():
input_embeds = arg
break

if input_embeds is not None and input_embeds.requires_grad:
if "attention_mask" in kwargs:
return kwargs["attention_mask"]
for arg in args:
if type(arg) is torch.Tensor and arg.dtype == torch.int32:
if (
type(arg) is torch.Tensor and
arg.dtype in (torch.int32, torch.int64, torch.bool)
):
return arg
return f(*args, **kwargs)
else:
# Eager
return f(*args, **kwargs)
Expand Down Expand Up @@ -2739,7 +2848,7 @@ def patch_gpt_oss_config():


def patch_gpt_oss_init_weights_modulelist_fix():
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
Expand Down Expand Up @@ -2784,7 +2893,7 @@ def patch_gpt_oss_for_grpo():
When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits.
This fixes the matrix multiplication dimension mismatch issue in GRPO training.
"""
if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
if "gpt_oss" not in _normalized_unsloth_model_name():
return

try:
Expand Down