Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def apply_fused_lm_head(forward):

cross_entropy_replacement = cross_entropy_replacement\
.replace(
"$KWARGS$",
"$KWARGS$",
"locals().get('loss_kwargs', {}) or locals().get('kwargs', {})"
)

Expand Down Expand Up @@ -1179,7 +1179,7 @@ def patch_gradient_checkpointing(module, source):
.replace("LAYER", layer).replace("MODULELIST_ITEM", modulelist_item)\
.replace("ARGS", args).replace("$", spaces)
forward = forward.replace(forward[span[0] : span[1]], replacer)

# Also fix init
spaces = init.find("def")
init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n"
Expand Down Expand Up @@ -1381,10 +1381,10 @@ def patch_gradient_accumulation(modeling_file, module):

functions = dir(modeling_file)
module = eval(f"modeling_file.{module}")
try:
try:
forward = module.forward
source = inspect.getsource(forward)
except:
except:
return None
has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD
if has_kwargs: return None
Expand Down Expand Up @@ -1450,6 +1450,10 @@ def unsloth_compile_transformers(
disable : bool = False,
return_logits : bool = False,
):
# Import the transformers logging module and get the logger named `modeling_{model_type}`.
from transformers import logging as transformers_logging
model_logger = transformers_logging.get_logger(f"modeling_{model_type}")

# All Unsloth Zoo code licensed under LGPLv3
disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1")
if fast_residual_stream:
Expand All @@ -1461,8 +1465,9 @@ def unsloth_compile_transformers(
modeling_file = eval(model_location)
if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return

# Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
exec("modeling_file.logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())
# Use the transformers model_type logger to suppress the message: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())


# torch_compile_options
UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
Expand Down Expand Up @@ -1792,7 +1797,7 @@ def unsloth_compile_transformers(
# Disable if torch < 2.5 or V100s 7.0 (Tesla T4 7.5 works) or old Triton < 3
if OLD_CUDA_ARCH_VERSION or OLD_TORCH_VERSION or OLD_TRITON_VERSION:
continue

module_class = eval(f"modeling_file.{module}")
if hasattr(module_class, "forward") and issubclass(module_class, GenerationMixin):
try:
Expand Down