Skip to content
Merged

Nightly #3720

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e53e185
Update _utils.py
danielhanchen Dec 9, 2025
150eadf
Merge branch 'main' into nightly
danielhanchen Dec 9, 2025
25a6250
Merge branch 'main' into nightly
danielhanchen Dec 9, 2025
6b908cf
Merge branch 'main' into nightly
danielhanchen Dec 10, 2025
f754bd2
Merge branch 'main' into nightly
danielhanchen Dec 10, 2025
30ade52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2025
94bbcaa
Merge branch 'main' into nightly
danielhanchen Dec 10, 2025
30b78c5
Merge branch 'main' into nightly
danielhanchen Dec 10, 2025
f357fc5
Merge branch 'main' into nightly
danielhanchen Dec 10, 2025
cde2d42
[FIX] [Transformers] VLM input embeds fix for gradients (#3715)
Datta0 Dec 12, 2025
a63a337
Merge branch 'main' into nightly
danielhanchen Dec 12, 2025
2c22ce6
Update rope_embedding.py
danielhanchen Dec 12, 2025
449430d
Merge branch 'main' into nightly
danielhanchen Dec 12, 2025
b5f1a77
Fixes
danielhanchen Dec 12, 2025
c94f595
Update _utils.py
danielhanchen Dec 12, 2025
01319d3
Update import_fixes.py
danielhanchen Dec 12, 2025
696a540
Update rl_replacements.py
danielhanchen Dec 12, 2025
ac54d6e
fix_openenv_no_vllm
danielhanchen Dec 12, 2025
9a98139
Fix
danielhanchen Dec 12, 2025
680f19f
Update __init__.py
danielhanchen Dec 12, 2025
fb763f3
Update __init__.py
danielhanchen Dec 12, 2025
f4f2a7f
Update __init__.py
danielhanchen Dec 12, 2025
e17b62f
Update import_fixes.py
danielhanchen Dec 12, 2025
f34eb0a
Update import_fixes.py
danielhanchen Dec 12, 2025
04ad21c
Update import_fixes.py
danielhanchen Dec 12, 2025
32b52a0
logger
danielhanchen Dec 12, 2025
0c9288f
Update __init__.py
danielhanchen Dec 12, 2025
b5b57b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2025
efb5801
Update __init__.py
danielhanchen Dec 12, 2025
12ccc47
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Dec 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ huggingfacenotorch = [
]
huggingface = [
"unsloth[huggingfacenotorch]",
"unsloth_zoo>=2025.12.3",
"unsloth_zoo>=2025.12.4",
"torchvision",
"unsloth[triton]",
]
Expand Down Expand Up @@ -523,7 +523,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3 ; ('linux' in sys_platform)",
]
colab-new = [
"unsloth_zoo>=2025.12.3",
"unsloth_zoo>=2025.12.4",
"packaging",
"tyro",
"transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,<=4.57.3",
Expand Down
23 changes: 15 additions & 8 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
import os, re, subprocess, inspect, functools
import numpy as np

# Log Unsloth is being used
os.environ["UNSLOTH_IS_PRESENT"] = "1"

# Check if modules that need patching are already imported
critical_modules = ["trl", "transformers", "peft"]
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# Fix some issues before importing other packages
from .import_fixes import (
fix_message_factory_issue,
Expand All @@ -31,10 +38,6 @@
del check_fbgemm_gpu_version
del torchvision_compatibility_check

# Check if modules that need patching are already imported
critical_modules = ["trl", "transformers", "peft"]
already_imported = [mod for mod in critical_modules if mod in sys.modules]

# This check is critical because Unsloth optimizes these libraries by modifying
# their code at import time. If they're imported first, the original (slower,
# more memory-intensive) implementations will be used instead of Unsloth's
Expand All @@ -43,7 +46,7 @@
# stacklevel=2 makes warning point to user's import line rather than this library code,
# showing them exactly where to fix the import order in their script
warnings.warn(
f"WARNING: Unsloth should be imported before {', '.join(already_imported)} "
f"WARNING: Unsloth should be imported before [{', '.join(already_imported)}] "
f"to ensure all optimizations are applied. Your code may run slower or encounter "
f"memory issues without these optimizations.\n\n"
f"Please restructure your imports with 'import unsloth' at the top of your file.",
Expand All @@ -63,16 +66,14 @@
# "pinned_use_cuda_host_register:True,"\
# "pinned_num_register_threads:8"

# Log Unsloth is being used
os.environ["UNSLOTH_IS_PRESENT"] = "1"

from importlib.metadata import version as importlib_version
from importlib.metadata import PackageNotFoundError

# Check for unsloth_zoo
try:
unsloth_zoo_version = importlib_version("unsloth_zoo")
if Version(unsloth_zoo_version) < Version("2025.12.3"):
if Version(unsloth_zoo_version) < Version("2025.12.4"):
print(
"Unsloth: Please update Unsloth and Unsloth-Zoo to the latest version!\n"
"Do this via `pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo`"
Expand Down Expand Up @@ -123,6 +124,8 @@
patch_ipykernel_hf_xet,
patch_trackio,
patch_datasets,
patch_enable_input_require_grads,
fix_openenv_no_vllm,
)

fix_xformers_performance_issue()
Expand All @@ -132,6 +135,8 @@
patch_ipykernel_hf_xet()
patch_trackio()
patch_datasets()
patch_enable_input_require_grads()
fix_openenv_no_vllm()

del fix_xformers_performance_issue
del fix_vllm_aimv2_issue
Expand All @@ -140,6 +145,8 @@
del patch_ipykernel_hf_xet
del patch_trackio
del patch_datasets
del patch_enable_input_require_grads
del fix_openenv_no_vllm

# Torch 2.4 has including_emulation
if DEVICE_TYPE == "cuda":
Expand Down
159 changes: 133 additions & 26 deletions unsloth/import_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from packaging.version import Version as TrueVersion
import re
import logging

UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
# Cannot import logger here since it'll import transformers
# from unsloth_zoo.log import logger


def Version(version):
Expand Down Expand Up @@ -70,9 +70,10 @@ def GetMessages(self, *args, **kwargs):
def GetPrototype(self, *args, **kwargs):
return

from unsloth_zoo.log import logger

if not hasattr(google.protobuf.message_factory, "MessageFactory"):
if UNSLOTH_ENABLE_LOGGING:
print("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
logger.info("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
google.protobuf.message_factory.MessageFactory = MessageFactory
elif (
hasattr(google.protobuf.message_factory, "MessageFactory")
Expand All @@ -82,8 +83,7 @@ def GetPrototype(self, *args, **kwargs):
and not hasattr(google.protobuf.message_factory, "GetMessageClass")
):
google.protobuf.message_factory.MessageFactory = MessageFactory
if UNSLOTH_ENABLE_LOGGING:
print("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
logger.info("Unsloth: Patching protobuf.MessageFactory as it doesn't exist")
elif (
hasattr(google.protobuf.message_factory, "MessageFactory")
and not hasattr(
Expand All @@ -97,8 +97,7 @@ def GetPrototype(self, descriptor):
return GetMessageClass(descriptor)

google.protobuf.message_factory.MessageFactory.GetPrototype = GetPrototype
if UNSLOTH_ENABLE_LOGGING:
print("Unsloth: Patching protobuf.MessageFactory.GetPrototype")
logger.info("Unsloth: Patching protobuf.MessageFactory.GetPrototype")
pass
except:
pass
Expand All @@ -110,6 +109,8 @@ def fix_xformers_performance_issue():
return
xformers_version = importlib_version("xformers")
if Version(xformers_version) < Version("0.0.29"):
from unsloth_zoo.log import logger

xformers_location = importlib.util.find_spec("xformers").origin
xformers_location = os.path.split(xformers_location)[0]
cutlass = Path(xformers_location) / "ops" / "fmha" / "cutlass.py"
Expand All @@ -126,13 +127,11 @@ def fix_xformers_performance_issue():
f.seek(0)
f.write(text)
f.truncate()
if UNSLOTH_ENABLE_LOGGING:
print(
"Unsloth: Patching Xformers to fix some performance issues."
)
logger.info(
"Unsloth: Patching Xformers to fix some performance issues."
)
except Exception as e:
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Failed patching Xformers with error = {str(e)}")
logger.info(f"Unsloth: Failed patching Xformers with error = {str(e)}")


# ValueError: 'aimv2' is already used by a Transformers config, pick another name.
Expand All @@ -141,6 +140,8 @@ def fix_vllm_aimv2_issue():
return
vllm_version = importlib_version("vllm")
if Version(vllm_version) < Version("0.10.1"):
from unsloth_zoo.log import logger

vllm_version = importlib.util.find_spec("vllm").origin
vllm_version = os.path.split(vllm_version)[0]
ovis_config = Path(vllm_version) / "transformers_utils" / "configs" / "ovis.py"
Expand All @@ -167,13 +168,11 @@ def fix_vllm_aimv2_issue():
f.seek(0)
f.write(text)
f.truncate()
if UNSLOTH_ENABLE_LOGGING:
print(
"Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`"
)
logger.info(
"Unsloth: Patching vLLM to fix `'aimv2' is already used by a Transformers config, pick another name.`"
)
except Exception as e:
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
logger.info(f"Unsloth: Failed patching vLLM with error = {str(e)}")


def fix_vllm_guided_decoding_params():
Expand Down Expand Up @@ -274,8 +273,74 @@ def check_fbgemm_gpu_version():
raise ImportError(
f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected. It might cause unexpected issues like segmentation faults. Please uninstall the current one by doing `pip uninstall fbgemm-gpu` && `pip install fbgemm-gpu` to install fbgemm-gpu 1.4.0 or newer!"
)
elif UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.")
from unsloth_zoo.log import logger

logger.info(f"Unsloth: fbgemm_gpu_genai=={fbgemm_gpu_version} detected.")


def patch_enable_input_require_grads():
"""
Patch transformers PreTrainedModel.enable_input_require_grads to handle vision models
that raise NotImplementedError from get_input_embeddings().

"""
import inspect
from transformers import PreTrainedModel

# Check if the original function iterates over self.modules() instead of just returning the enable_input_require_grads
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a bare except: is generally discouraged as it can catch unexpected exceptions like SystemExit or KeyboardInterrupt, making it harder to debug and interrupt the program. It's better to catch specific exceptions, or Exception at the very least.

Suggested change
# Check if the original function iterates over self.modules() instead of just returning the enable_input_require_grads
except Exception:

# Ref: https://github.com/huggingface/transformers/pull/41993/files#diff-6b72b98c4c2dcfc6cc606843917733f5d858374fbc22a735ff483bbc0c1e63eaL1979-R1996
try:
original_source = inspect.getsource(PreTrainedModel.enable_input_require_grads)
except:
return

# Only patch if the new pattern exists (iterating over self.modules())
if "for module in self.modules()" not in original_source:
return

def _patched_enable_input_require_grads(self):
def make_inputs_require_grads(module, input, output):
output.requires_grad_(True)

hooks = []
seen_modules = set()

for module in self.modules():
if not (
isinstance(module, PreTrainedModel)
and hasattr(module, "get_input_embeddings")
):
continue

try:
input_embeddings = module.get_input_embeddings()
except NotImplementedError:
# Vision models may not implement get_input_embeddings - skip them
# For GLM V4.6 for example, this skips only `self.visual`
continue

if input_embeddings is None:
continue

embedding_id = id(input_embeddings)
if embedding_id in seen_modules:
continue

seen_modules.add(embedding_id)
hooks.append(
input_embeddings.register_forward_hook(make_inputs_require_grads)
)

self._require_grads_hooks = hooks
if hooks:
self._require_grads_hook = hooks[0]

PreTrainedModel.enable_input_require_grads = _patched_enable_input_require_grads
from unsloth_zoo.log import logger

logger.info(
"Unsloth: Patched enable_input_require_grads for vision model compatibility"
)


def torchvision_compatibility_check():
Expand Down Expand Up @@ -313,7 +378,49 @@ def torchvision_compatibility_check():
f"but found torchvision=={torchvision_version}. "
f"Please refer to https://pytorch.org/get-started/previous-versions/ for more information."
)
elif UNSLOTH_ENABLE_LOGGING:
print(
f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible."
)
from unsloth_zoo.log import logger

logger.info(
f"Unsloth: torch=={torch_version} and torchvision=={torchvision_version} are compatible."
)


# Fix TRL OpenEnv 0.26 NameError: name 'SamplingParams' is not defined
def fix_openenv_no_vllm():
if importlib.util.find_spec("trl") is None:
return
trl_location = importlib.util.find_spec("trl").origin
trl_location = os.path.split(trl_location)[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a minor style issue here. According to PEP 8, there should be no spaces around the = sign for keyword arguments. It should be encoding="utf-8".

Suggested change
trl_location = os.path.split(trl_location)[0]
with open(openenv, "r+", encoding="utf-8") as f:

openenv = Path(trl_location) / "experimental" / "openenv" / "utils.py"
if not openenv.exists():
return
from unsloth_zoo.log import logger

try:
with open(openenv, "r+", encoding = "utf-8") as f:
text = f.read()
bad = (
"if is_vllm_available():\n"
" from vllm import SamplingParams\n"
" from vllm.sampling_params import GuidedDecodingParams\n"
)
if bad + "\n" + "\n" in text:
text = text.replace(
bad + "\n" + "\n",
bad
+ (
"else:\n"
" from typing import Any\n"
" SamplingParams = Any\n"
" GuidedDecodingParams = Any\n"
"\n"
),
)
f.seek(0)
f.write(text)
f.truncate()
logger.info(
"Unsloth: Patching TRL OpenEnv to fix SamplingParams not defined"
)
except Exception as e:
logger.info(f"Unsloth: Failed patching TRL OpenEnv with error = {str(e)}")
16 changes: 9 additions & 7 deletions unsloth/kernels/rope_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@
from .utils import calculate_settings, torch_gpu_device, torch_device_stream


@triton.heuristics(
{
"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
"HAS_ROPE_INDICES": lambda args: bool(args["HAS_ROPE_INDICES"]),
}
)
@triton.jit
def _rope_embedding_QK(
Q,
Q_batch_stride,
Expand Down Expand Up @@ -104,6 +97,15 @@ def _rope_embedding_QK(
tl.store(k_ptr + half_head_dim + col_offsets, k1 * cos1 + k0 * sin1, mask = mask)


_rope_embedding_QK = triton.jit(_rope_embedding_QK)
_rope_embedding_QK = triton.heuristics(
{
"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),
"HAS_ROPE_INDICES": lambda args: bool(args["HAS_ROPE_INDICES"]),
}
)(_rope_embedding_QK)


ROPE_GROUP_SIZE: int = 4


Expand Down
12 changes: 11 additions & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.12.4"
__version__ = "2025.12.5"

__all__ = [
"SUPPORTS_BFLOAT16",
Expand Down Expand Up @@ -413,6 +413,16 @@ def filter(self, x):
except:
pass

# Flax classes are deprecated and will be removed in Diffusers v1.0.0.
try:
from diffusers.utils import logger as diffusers_logger

diffusers_logger.addFilter(HideLoggingMessage("are deprecated"))
del diffusers_logger
except:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a bare except: is generally discouraged as it can catch unexpected exceptions like SystemExit or KeyboardInterrupt, making it harder to debug and interrupt the program. It's better to catch specific exceptions, or Exception at the very least.

Suggested change
except:
except Exception:

pass


# Errors out on
# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
from transformers.modeling_utils import logger as transformers_logger
Expand Down
Loading