Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
c39f56f
Fix mamba
danielhanchen Aug 9, 2025
4bd35c5
Update loader.py
danielhanchen Aug 9, 2025
1f0a4c3
Update vision.py
danielhanchen Aug 9, 2025
3cb9719
Update loader.py
danielhanchen Aug 9, 2025
a4081af
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
1432eac
Filter vLLM standby logs (#3131)
Datta0 Aug 13, 2025
fd1124a
Update loader.py
danielhanchen Aug 13, 2025
b78189b
Add scaler
danielhanchen Aug 13, 2025
cd2e284
Update llama.py
danielhanchen Aug 13, 2025
5e976a5
Update _utils.py
danielhanchen Aug 13, 2025
f451adf
Versioning
danielhanchen Aug 13, 2025
dafc7b8
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
bf5c402
Merge branch 'main' into nightly
danielhanchen Aug 13, 2025
3b82c42
GPT OSS fix
danielhanchen Aug 14, 2025
61366ef
GPT OSS fix
danielhanchen Aug 14, 2025
de043d9
Update loader.py
danielhanchen Aug 14, 2025
c1ef6f1
Update vision.py
danielhanchen Aug 14, 2025
f18cd26
Update vision.py
danielhanchen Aug 14, 2025
0215224
Update loader.py
danielhanchen Aug 14, 2025
5ed4a46
Update vision.py
danielhanchen Aug 15, 2025
e2ebb99
Merge branch 'main' into nightly
danielhanchen Aug 15, 2025
a222558
Update vision.py
danielhanchen Aug 15, 2025
cdcfe7d
Merge branch 'main' into nightly
danielhanchen Aug 15, 2025
6cffb1c
Update llama.py
danielhanchen Aug 15, 2025
15d33a5
Update llama.py
danielhanchen Aug 15, 2025
95a4daf
Update llama.py
danielhanchen Aug 15, 2025
4104bba
Versioning
danielhanchen Aug 15, 2025
8cc1999
Update mapper.py
danielhanchen Aug 15, 2025
a5dffd7
Merge branch 'main' into nightly
danielhanchen Aug 16, 2025
ffda8a7
Update vision.py
danielhanchen Aug 16, 2025
cdf2e17
Update vision.py
danielhanchen Aug 16, 2025
941d1ae
Update vision.py
danielhanchen Aug 16, 2025
73fa72c
Upcast norms
danielhanchen Aug 16, 2025
e4bbeef
Update loader.py
danielhanchen Aug 16, 2025
c8d00be
Update vision.py
danielhanchen Aug 16, 2025
564b6f8
Upcast layernorms
danielhanchen Aug 17, 2025
b8a34b4
Update llama.py
danielhanchen Aug 17, 2025
509fcb5
Update llama.py
danielhanchen Aug 17, 2025
27f1a2e
Update llama.py
danielhanchen Aug 18, 2025
931851a
Update llama.py
danielhanchen Aug 18, 2025
3b9057b
Update llama.py
danielhanchen Aug 18, 2025
3dd87bb
Update llama.py
danielhanchen Aug 18, 2025
f3f2b51
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
b757faf
Update save.py
danielhanchen Aug 18, 2025
2e86333
Update rl.py
danielhanchen Aug 18, 2025
b01e948
Update pyproject.toml
danielhanchen Aug 18, 2025
b064255
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
a751fd7
Update rl.py
danielhanchen Aug 18, 2025
c5d22e1
Merge branch 'main' into nightly
danielhanchen Aug 18, 2025
3cb6eaf
Update rl_replacements.py
danielhanchen Aug 18, 2025
de77a26
Update rl.py
danielhanchen Aug 19, 2025
27ca531
Update rl.py
danielhanchen Aug 19, 2025
6514c8e
Update rl.py
danielhanchen Aug 19, 2025
3e29ae7
Update _utils.py
danielhanchen Aug 19, 2025
a42f624
Update __init__.py
danielhanchen Aug 19, 2025
9437f9e
Torch 2.8
danielhanchen Aug 19, 2025
1dd99a2
Update rl_replacements.py
danielhanchen Aug 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ cu126onlytorch260 = [
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
# xformers 0.0.30 wheels for PyTorch 2.7.0 + CUDA 11.8: one wheel per CPython
# version (cp39-cp312) and platform, selected via PEP 508 environment markers.
cu118onlytorch270 = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
cu126onlytorch270 = [
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
Expand All @@ -227,6 +237,30 @@ cu128onlytorch270 = [
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
# xformers 0.0.31.post1 wheels for PyTorch 2.7.1: single abi3 (cp39+) wheel per
# platform, so no per-Python-version markers are needed here.
cu118onlytorch271 = [
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
cu126onlytorch271 = [
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
cu128onlytorch271 = [
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
# xformers 0.0.32.post2 wheels for PyTorch 2.8.0.
cu118onlytorch280 = [
# NOTE(review): torch 2.8 / xformers 0.0.32.post2 publish no cu118 wheels, so
# the nearest cu126 builds are used here — confirm this fallback is intended.
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
cu126onlytorch280 = [
# Fixed: previously pointed at the cu128 wheel index (copy-paste shift).
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
cu128onlytorch280 = [
# Fixed: previously pointed at the cu129 wheel index (copy-paste shift).
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; platform_system == 'Linux'",
"xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; platform_system == 'Windows'",
]
cu118 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
Expand Down Expand Up @@ -337,6 +371,11 @@ cu126-torch260 = [
"bitsandbytes>=0.45.5",
"unsloth[cu126onlytorch260]",
]
# CUDA 11.8 + torch 2.7.0 extra: HF stack, bitsandbytes, and the pinned
# cu118onlytorch270 xformers wheels defined above.
cu118-torch270 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch270]",
]
cu126-torch270 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
Expand All @@ -347,6 +386,36 @@ cu128-torch270 = [
"bitsandbytes>=0.45.5",
"unsloth[cu128onlytorch270]",
]
# Convenience extras for torch 2.7.1 / 2.8.0: each cuXXX-torchYYY set pulls the
# HF stack, bitsandbytes, and the matching cuXXXonlytorchYYY xformers pin above.
cu118-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch271]",
]
cu126-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu126onlytorch271]",
]
cu128-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu128onlytorch271]",
]
cu118-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch280]",
]
cu126-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu126onlytorch280]",
]
cu128-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu128onlytorch280]",
]
# NOTE(review): only the HF stack — presumably Kaggle images preinstall
# torch/CUDA/xformers, so no wheel pins are needed; verify against docs.
kaggle = [
"unsloth[huggingface]",
]
Expand Down Expand Up @@ -540,6 +609,12 @@ cu126-ampere-torch260 = [
"unsloth[cu126onlytorch260]",
"unsloth[flashattention]",
]
# Ampere variant of cu118-torch270: same set plus the flashattention extra.
cu118-ampere-torch270 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch270]",
"unsloth[flashattention]",
]
cu126-ampere-torch270 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
Expand All @@ -552,7 +627,42 @@ cu128-ampere-torch270 = [
"unsloth[cu128onlytorch270]",
"unsloth[flashattention]",
]

# Ampere variants for torch 2.7.1 / 2.8.0: each mirrors its cuXXX-torchYYY
# extra and additionally pulls the flashattention extra.
cu118-ampere-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch271]",
"unsloth[flashattention]",
]
cu126-ampere-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu126onlytorch271]",
"unsloth[flashattention]",
]
cu128-ampere-torch271 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu128onlytorch271]",
"unsloth[flashattention]",
]
cu118-ampere-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu118onlytorch280]",
"unsloth[flashattention]",
]
cu126-ampere-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu126onlytorch280]",
"unsloth[flashattention]",
]
cu128-ampere-torch280 = [
"unsloth[huggingface]",
"bitsandbytes>=0.45.5",
"unsloth[cu128onlytorch280]",
"unsloth[flashattention]",
]
flashattentiontorch260abiFALSEcu12x = [
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.9'",
"flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10'",
Expand Down
25 changes: 25 additions & 0 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

try:
    # Work around "AttributeError: 'MessageFactory' object has no attribute
    # 'GetPrototype'". This MUST run before anything else in the package,
    # primarily because tensorflow can pull in a protobuf whose legacy
    # MessageFactory API has been removed/renamed.
    import google.protobuf.message_factory

    class MessageFactory:
        """Inert stand-in exposing the legacy MessageFactory method names."""
        def CreatePrototype(self, *args, **kwargs): return
        def GetMessages(self, *args, **kwargs): return
        def GetPrototype(self, *args, **kwargs): return

    _factory_module = google.protobuf.message_factory
    _missing = object()
    _existing = getattr(_factory_module, "MessageFactory", _missing)
    if _existing is _missing:
        # No MessageFactory at all -> install the stub wholesale.
        _factory_module.MessageFactory = MessageFactory
    elif not hasattr(_existing, "GetPrototype"):
        if hasattr(_factory_module, "GetMessageClass"):
            # Newer protobuf renamed the API -> delegate the old name to it.
            GetMessageClass = _factory_module.GetMessageClass
            def GetPrototype(self, descriptor):
                return GetMessageClass(descriptor)
            _factory_module.MessageFactory.GetPrototype = GetPrototype
        else:
            # Old method gone and no replacement available -> use the stub.
            _factory_module.MessageFactory = MessageFactory
    pass
except:
    # Best-effort patch: any failure (protobuf absent, layout changed) is
    # deliberately ignored so unsloth can still import.
    pass

import warnings, importlib, sys
from packaging.version import Version
import os, re, subprocess, inspect
Expand Down
6 changes: 5 additions & 1 deletion unsloth/_auto_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
elif v < V('2.5.1'): x = 'cu{}{}-torch250'
elif v <= V('2.5.1'): x = 'cu{}{}-torch251'
elif v < V('2.7.0'): x = 'cu{}{}-torch260'
elif v < V('2.8.0'): x = 'cu{}{}-torch270'
elif v < V('2.7.9'): x = 'cu{}{}-torch270'
elif v < V('2.8.0'): x = 'cu{}{}-torch271'
elif v < V('2.8.9'): x = 'cu{}{}-torch280'
else: raise RuntimeError(f"Torch = {v} too new!")
if v > V('2.6.9') and cuda not in ("11.8", "12.6", "12.8"):
raise RuntimeError(f"CUDA = {cuda} not supported!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
32 changes: 32 additions & 0 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,38 @@ def filter(self, x): return not (self.text in x.getMessage())
except:
pass

# Silence known-noisy transformers log messages ("Using a slow image processor
# as `use_fast`..." and "`use_cache=True` is incompatible with gradient
# checkpointing...") by attaching a HideLoggingMessage filter to every logger
# observed to emit them. Each (module, text) pair is handled independently so
# a module missing on older/newer transformers versions is silently skipped.
for _noisy_module, _noisy_text in (
    ("transformers.processing_utils",                  "`use_fast`"),
    ("transformers.models.auto.image_processing_auto", "`use_fast`"),
    ("transformers.trainer",                           "`use_cache=True`"),
    ("transformers.utils.generic",                     "`use_cache=True`"),
):
    try:
        # fromlist forces __import__ to return the submodule itself.
        _noisy_logger = __import__(_noisy_module, fromlist = ("logger",)).logger
        _noisy_logger.addFilter(HideLoggingMessage(_noisy_text))
        del _noisy_logger
    except:
        pass
del _noisy_module, _noisy_text

# Errors out on
# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
from transformers.modeling_utils import logger as transformers_logger
Expand Down
24 changes: 22 additions & 2 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,18 @@ class Unsloth{RLConfig_name}({RLConfig_name}):
default = -1,
metadata = {{'help': 'Chunk size to reduce memory usage. -1 is most efficient.'}},
)
{max_seq_length_pre}
def __init__({RLConfig_arguments},
vllm_sampling_params = None,
unsloth_num_chunks = -1,
{max_seq_length_call}
**kwargs,
):
{RLConfig_extra_args}
super().__init__({RLConfig_call_args}{RLConfig_kwargs})
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
{max_seq_length_post}
pass

{RLTrainer_extras}
Expand Down Expand Up @@ -353,9 +356,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
" max_length = args.max_length\n"\
" else:\n"\
" model_max_length = getattr(model, 'max_seq_length', None)\n"\
" # print(model_max_length, 'mml1')\n"\
" if model_max_length is None: model_max_length = getattr(model, 'max_length', None)\n"\
" # print(model_max_length, 'mml2')\n"\
" if model_max_length is not None:\n"\
" args.max_length = model_max_length\n"\
" max_length = args.max_length\n"\
Expand Down Expand Up @@ -535,6 +536,21 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
extra_args += learning_rate_check
pass

# Check if max_seq_length is NOT defined (max_length is now default)
if "max_seq_length" not in call_args and "max_length" in call_args:
max_seq_length_pre = \
"""max_seq_length : Optional[int] = field(
default = None,
metadata = {'help': 'Maximum sequence length to truncate to.'},
)"""
max_seq_length_call = "max_seq_length = max_seq_length,"
max_seq_length_post = "self.max_seq_length = max_seq_length"
else:
max_seq_length_pre = ""
max_seq_length_call = ""
max_seq_length_post = ""
pass

# Add output_dir saving
if "output_dir" in call_args:
# Default checks
Expand Down Expand Up @@ -666,6 +682,10 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
RLTrainer_post = RLTrainer_post,
RL_pre = RL_pre,

max_seq_length_pre = max_seq_length_pre,
max_seq_length_call = max_seq_length_call,
max_seq_length_post = max_seq_length_post,

selective_log_softmax_code = selective_log_softmax_code,
)

Expand Down