Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
756 commits
Select commit Hold shift + click to select a range
b8b0f9c
_wrap_fast_inference
danielhanchen Mar 5, 2025
6f0857b
Update llama.py
danielhanchen Mar 5, 2025
109364b
Update llama.py
danielhanchen Mar 5, 2025
dd4bd07
Update llama.py
danielhanchen Mar 5, 2025
b356fce
Update llama.py
danielhanchen Mar 5, 2025
e022016
Update llama.py
danielhanchen Mar 5, 2025
12094a7
Update llama.py
danielhanchen Mar 5, 2025
2836128
Update llama.py
danielhanchen Mar 5, 2025
c956616
Update llama.py
danielhanchen Mar 5, 2025
e887f43
Update llama.py
danielhanchen Mar 5, 2025
95f872d
Update llama.py
danielhanchen Mar 5, 2025
647dbb4
Update llama.py
danielhanchen Mar 5, 2025
f640c8d
Update _utils.py
danielhanchen Mar 5, 2025
91a4fce
SFT dataset prepare
danielhanchen Mar 5, 2025
4495148
Update pyproject.toml
danielhanchen Mar 5, 2025
f41dff5
Update rl_replacements.py
danielhanchen Mar 5, 2025
0a3dbfa
Update rl_replacements.py
danielhanchen Mar 5, 2025
7d8f100
Update rl_replacements.py
danielhanchen Mar 5, 2025
413ea80
Update rl.py
danielhanchen Mar 5, 2025
3f5ce93
Update llama.py
danielhanchen Mar 5, 2025
185bced
Update llama.py
danielhanchen Mar 5, 2025
fd11ad7
Update utils.py
danielhanchen Mar 5, 2025
97ed0b4
bug fix
danielhanchen Mar 5, 2025
68eca88
Update llama.py
danielhanchen Mar 5, 2025
5daf9b5
Update llama.py
danielhanchen Mar 5, 2025
858bb76
Update llama.py
danielhanchen Mar 5, 2025
daedc34
Update llama.py
danielhanchen Mar 5, 2025
95e2371
Update llama.py
danielhanchen Mar 5, 2025
fccd68a
Update __init__.py
danielhanchen Mar 5, 2025
c665e0b
Update _utils.py
danielhanchen Mar 6, 2025
d207daf
Merge branch 'main' into nightly
danielhanchen Mar 6, 2025
dbf7eac
Update _utils.py
danielhanchen Mar 6, 2025
b55f6d9
Update _utils.py
danielhanchen Mar 6, 2025
c7abf7d
Update _utils.py
danielhanchen Mar 6, 2025
98d5ab0
Update _utils.py
danielhanchen Mar 6, 2025
f72794e
Update rl.py
danielhanchen Mar 6, 2025
1ec0ee2
Update rl.py
danielhanchen Mar 6, 2025
5350c6a
Update rl.py
danielhanchen Mar 6, 2025
9009ef0
Update _utils.py
danielhanchen Mar 6, 2025
7f7899d
Update __init__.py
danielhanchen Mar 6, 2025
334bd77
Update _utils.py
danielhanchen Mar 6, 2025
ade31e2
Version
danielhanchen Mar 6, 2025
a31e45b
Merge branch 'main' into nightly
danielhanchen Mar 6, 2025
8015ff2
versioning
danielhanchen Mar 6, 2025
d8777be
Update _utils.py
danielhanchen Mar 6, 2025
132b838
Update llama.py
danielhanchen Mar 6, 2025
21faa50
Update llama.py
danielhanchen Mar 6, 2025
af5d875
Merge branch 'main' into nightly
danielhanchen Mar 6, 2025
904e1c5
Bug fixes
danielhanchen Mar 7, 2025
761bb8f
FastModel
danielhanchen Mar 8, 2025
7bf880f
__doc__
danielhanchen Mar 8, 2025
c93b51b
Update vision.py
danielhanchen Mar 8, 2025
f8867be
Update loader.py
danielhanchen Mar 8, 2025
2ab1828
Update loader.py
danielhanchen Mar 8, 2025
e05baed
Update loader.py
danielhanchen Mar 8, 2025
31012a7
version
danielhanchen Mar 8, 2025
a8bf659
Merge branch 'main' into nightly
danielhanchen Mar 9, 2025
d72e3e0
move use_modelscope to _utils (#1938)
KareemMusleh Mar 9, 2025
7e82339
Don't use revision when loading model_config and is_peft=True (#1949)
wiwu2390 Mar 9, 2025
4904c48
More syntax warnings (#1944)
KareemMusleh Mar 9, 2025
7aaa605
Update loader.py
danielhanchen Mar 9, 2025
a585536
Full finetuning and other fixes
danielhanchen Mar 10, 2025
133c0ae
UNSLOTH_ENABLE_FULL_FINETUNING
danielhanchen Mar 10, 2025
9d5aa5c
Update loader.py
danielhanchen Mar 10, 2025
934ad16
Update loader.py
danielhanchen Mar 10, 2025
76f2f2a
Update loader.py
danielhanchen Mar 10, 2025
f763ed6
Update vision.py
danielhanchen Mar 10, 2025
0df9518
Update vision.py
danielhanchen Mar 10, 2025
ced164e
full finetuning
danielhanchen Mar 10, 2025
5b45f0f
Update loader.py
danielhanchen Mar 10, 2025
23d45cf
Update loader.py
danielhanchen Mar 10, 2025
bdebea7
Update loader.py
danielhanchen Mar 10, 2025
04f1abc
Update _utils.py
danielhanchen Mar 10, 2025
4c0a8d6
max_seq_length
danielhanchen Mar 10, 2025
8f16ce0
Update rl.py
danielhanchen Mar 10, 2025
8b16a16
Update rl.py
danielhanchen Mar 10, 2025
a8c96d3
Update rl.py
danielhanchen Mar 10, 2025
739b1dd
Update pyproject.toml
danielhanchen Mar 11, 2025
c555388
AutoModelForImageTextToText
danielhanchen Mar 11, 2025
77fec99
Update mapper.py
danielhanchen Mar 11, 2025
c539fc6
Update pyproject.toml
danielhanchen Mar 11, 2025
3ddcf84
Update _utils.py
danielhanchen Mar 11, 2025
3aa2d95
Update _utils.py
danielhanchen Mar 11, 2025
a3541c0
Update _utils.py
danielhanchen Mar 11, 2025
a4faf0f
Batch samples
danielhanchen Mar 12, 2025
eb0add4
Update loader.py
danielhanchen Mar 12, 2025
b556785
Update loader.py
danielhanchen Mar 12, 2025
ead1b3b
Update loader.py
danielhanchen Mar 12, 2025
b388d8d
Update loader.py
danielhanchen Mar 12, 2025
80eac80
Update _utils.py
danielhanchen Mar 12, 2025
d6d862e
Update loader.py
danielhanchen Mar 12, 2025
ea6aae6
Update vision.py
danielhanchen Mar 12, 2025
0c4ebb3
Update loader.py
danielhanchen Mar 12, 2025
528e8f0
Update vision.py
danielhanchen Mar 12, 2025
152b376
Update vision.py
danielhanchen Mar 12, 2025
2fdeecd
Update vision.py
danielhanchen Mar 12, 2025
ceda772
Update mapper.py
danielhanchen Mar 12, 2025
0df6ad4
Merge branch 'main' into nightly
danielhanchen Mar 12, 2025
f386f0f
Update vision.py
danielhanchen Mar 12, 2025
b6187c6
Temporary patches
danielhanchen Mar 13, 2025
bb59cec
Update loader.py
danielhanchen Mar 13, 2025
3326c4f
model names
danielhanchen Mar 13, 2025
bb193e4
Gemma 3 chat template
danielhanchen Mar 13, 2025
57a5442
Bug fixes
danielhanchen Mar 13, 2025
8457c75
Update vision.py
danielhanchen Mar 13, 2025
bc735a7
Update vision.py
danielhanchen Mar 13, 2025
ed588ee
Update vision.py
danielhanchen Mar 13, 2025
a3637fa
Update vision.py
danielhanchen Mar 13, 2025
6218eae
Update vision.py
danielhanchen Mar 13, 2025
9005a57
Update llama.py
danielhanchen Mar 13, 2025
97f40bd
Update llama.py
danielhanchen Mar 13, 2025
24cd9f7
Update rl.py
danielhanchen Mar 13, 2025
b0d9ee0
Update chat_templates.py
danielhanchen Mar 13, 2025
07f47a4
Update chat_templates.py
danielhanchen Mar 13, 2025
caec8ff
Update vision.py
danielhanchen Mar 13, 2025
c96eab5
Update vision.py
danielhanchen Mar 13, 2025
6e58d97
Update vision.py
danielhanchen Mar 13, 2025
dd17676
Update loader.py
danielhanchen Mar 13, 2025
7d0893b
Update vision.py
danielhanchen Mar 13, 2025
8b51a7d
Update vision.py
danielhanchen Mar 13, 2025
833e295
Revert
danielhanchen Mar 13, 2025
20ae25a
Update _utils.py
danielhanchen Mar 13, 2025
067fb5e
forced precision
danielhanchen Mar 13, 2025
7493af8
Autocast
danielhanchen Mar 13, 2025
6dcd0bf
Update vision.py
danielhanchen Mar 13, 2025
c6eae35
Update vision.py
danielhanchen Mar 13, 2025
d1f09cf
Update rl.py
danielhanchen Mar 13, 2025
e0e31d9
Update vision.py
danielhanchen Mar 13, 2025
57576a5
Update vision.py
danielhanchen Mar 13, 2025
3b6c379
Update vision.py
danielhanchen Mar 13, 2025
b284ed5
Update vision.py
danielhanchen Mar 13, 2025
ed80c07
Update vision.py
danielhanchen Mar 13, 2025
171ad42
Update rl.py
danielhanchen Mar 13, 2025
9f6d280
vLLM fixes
danielhanchen Mar 14, 2025
f525442
constexpr
danielhanchen Mar 14, 2025
6e7d5be
Update vision.py
danielhanchen Mar 14, 2025
e388265
Update vision.py
danielhanchen Mar 14, 2025
2def2a5
Update vision.py
danielhanchen Mar 14, 2025
69f4581
Update rl.py
danielhanchen Mar 14, 2025
13788ab
Update llama.py
danielhanchen Mar 14, 2025
7ccacc3
Update llama.py
danielhanchen Mar 14, 2025
a219029
Update llama.py
danielhanchen Mar 14, 2025
d9d1116
Update llama.py
danielhanchen Mar 14, 2025
050cb85
Update llama.py
danielhanchen Mar 14, 2025
ae54a69
Update llama.py
danielhanchen Mar 14, 2025
5a4f410
Update llama.py
danielhanchen Mar 14, 2025
c21dba4
Update llama.py
danielhanchen Mar 14, 2025
1f7f78e
Update _utils.py
danielhanchen Mar 14, 2025
edd6181
Update _utils.py
danielhanchen Mar 14, 2025
6547468
Update _utils.py
danielhanchen Mar 14, 2025
7afe411
Update _utils.py
danielhanchen Mar 14, 2025
13b4a95
Update save.py
danielhanchen Mar 14, 2025
2b76350
New models
danielhanchen Mar 14, 2025
1b45ab6
Triton windows update (#1976)
Captain-T2004 Mar 14, 2025
6aaf377
Update RMS LayerNorm implementation, and list compr. change in chat t…
NinoRisteski Mar 14, 2025
94f075c
Update Zoo
danielhanchen Mar 14, 2025
1d6c395
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Mar 14, 2025
8ec6e8b
Merge branch 'main' into nightly
danielhanchen Mar 14, 2025
4ef899c
Update llama.py
danielhanchen Mar 14, 2025
9cd4f47
Update llama.py
danielhanchen Mar 14, 2025
5e17f22
Update vision.py
danielhanchen Mar 14, 2025
0003ead
Update vision.py
danielhanchen Mar 14, 2025
8f455fc
Update vision.py
danielhanchen Mar 14, 2025
790833e
Update vision.py
danielhanchen Mar 14, 2025
ba8408d
Update vision.py
danielhanchen Mar 14, 2025
e78fe39
Update vision.py
danielhanchen Mar 14, 2025
6b5eb3c
Update vision.py
danielhanchen Mar 14, 2025
9703843
Update vision.py
danielhanchen Mar 14, 2025
f6efd4d
Update vision.py
danielhanchen Mar 14, 2025
9bc273b
Update vision.py
danielhanchen Mar 14, 2025
26045d8
Update vision.py
danielhanchen Mar 14, 2025
f988ed4
Update vision.py
danielhanchen Mar 14, 2025
5d98f5b
Update rl_replacements.py
danielhanchen Mar 14, 2025
4079dba
Update vision.py
danielhanchen Mar 14, 2025
9554dd5
grpo fix
danielhanchen Mar 14, 2025
3a76607
Update rl_replacements.py
danielhanchen Mar 14, 2025
1d73f9e
Update vision.py
danielhanchen Mar 14, 2025
35383c3
Update rl_replacements.py
danielhanchen Mar 14, 2025
fc74d92
Update vision.py
danielhanchen Mar 14, 2025
3ac4fa5
Update mapper.py
danielhanchen Mar 14, 2025
b75698c
Update vision.py
danielhanchen Mar 14, 2025
87363a6
Update vision.py
danielhanchen Mar 14, 2025
1a17945
Update loader.py
danielhanchen Mar 14, 2025
e72a79a
Merge branch 'main' into nightly
danielhanchen Mar 14, 2025
21867b7
Update vision.py
danielhanchen Mar 14, 2025
a6e86f4
Update save.py
danielhanchen Mar 14, 2025
b9de6dc
Update save.py
danielhanchen Mar 14, 2025
3c3d9b3
Update save.py
danielhanchen Mar 14, 2025
2b8c15c
Merge branch 'main' into nightly
danielhanchen Mar 15, 2025
0f0e6eb
Update rl.py
danielhanchen Mar 16, 2025
8ab8c6c
Update _utils.py
danielhanchen Mar 16, 2025
48a33ad
Merge branch 'main' into nightly
danielhanchen Mar 16, 2025
e50fb74
Version
danielhanchen Mar 16, 2025
69659f6
Update pyproject.toml
danielhanchen Mar 16, 2025
ee07fb9
Update llama.py
danielhanchen Mar 16, 2025
cfa846e
Update llama.py
danielhanchen Mar 16, 2025
b1ec22d
bug fix #2008 (#2039)
void-mckenzie Mar 16, 2025
ce4558b
fix (#2051)
KareemMusleh Mar 16, 2025
97c2a88
Update loader.py
danielhanchen Mar 17, 2025
64c2918
Update pyproject.toml
danielhanchen Mar 17, 2025
60b3da5
Update pyproject.toml
danielhanchen Mar 17, 2025
19c6928
Update vision.py
danielhanchen Mar 17, 2025
f358b79
more prints
danielhanchen Mar 17, 2025
301f7fd
Update loader.py
danielhanchen Mar 17, 2025
df554bc
LoRA 16bit fix
danielhanchen Mar 17, 2025
82debd2
Update vision.py
danielhanchen Mar 17, 2025
682de74
Update vision.py
danielhanchen Mar 17, 2025
28b4128
Update _utils.py
danielhanchen Mar 17, 2025
6d596da
Update vision.py
danielhanchen Mar 17, 2025
9a356a7
move forced float32
danielhanchen Mar 17, 2025
9f55885
Update _utils.py
danielhanchen Mar 17, 2025
12de176
Update _utils.py
danielhanchen Mar 17, 2025
5ca4f5c
Update _utils.py
danielhanchen Mar 17, 2025
3cf8d07
Update _utils.py
danielhanchen Mar 17, 2025
78e85e3
move print
danielhanchen Mar 17, 2025
07ea763
Update _utils.py
danielhanchen Mar 17, 2025
0cf990f
disable bfloat16
danielhanchen Mar 18, 2025
d3eaf9e
Fix forced float32
danielhanchen Mar 18, 2025
984273a
move float32
danielhanchen Mar 18, 2025
457fc12
Ensure trust_remote_code propegates down to unsloth_compile_transform…
CuppaXanax Mar 18, 2025
997fa41
Update _utils.py
danielhanchen Mar 18, 2025
e2dc4b3
Merge branch 'nightly' of https://github.com/unslothai/unsloth into n…
danielhanchen Mar 18, 2025
420380d
Show both `peft_error` and `autoconfig_error`, not just `autoconfig_e…
IsaacBreen Mar 18, 2025
0e54be4
fix error message (#2046)
KareemMusleh Mar 18, 2025
4756979
Update vision.py
danielhanchen Mar 18, 2025
50c98b5
Update _utils.py
danielhanchen Mar 18, 2025
23bac1d
Update pyproject.toml
danielhanchen Mar 18, 2025
aed7d20
Update __init__.py
danielhanchen Mar 18, 2025
7fcda1a
Update __init__.py
danielhanchen Mar 18, 2025
f0de417
Update vision.py
danielhanchen Mar 18, 2025
2e377bc
Update vision.py
danielhanchen Mar 18, 2025
5d64bff
Update vision.py
danielhanchen Mar 18, 2025
c965c86
Update vision.py
danielhanchen Mar 18, 2025
d9e984e
Update vision.py
danielhanchen Mar 18, 2025
eb959ca
Update vision.py
danielhanchen Mar 18, 2025
0372df7
Update vision.py
danielhanchen Mar 18, 2025
d767920
Update vision.py
danielhanchen Mar 18, 2025
ea19392
Update vision.py
danielhanchen Mar 18, 2025
9486268
Update rl_replacements.py
danielhanchen Mar 18, 2025
e87368f
Update rl_replacements.py
danielhanchen Mar 18, 2025
2a620fc
Update rl_replacements.py
danielhanchen Mar 18, 2025
8d2885f
Update rl_replacements.py
danielhanchen Mar 18, 2025
b9e3455
Update vision.py
danielhanchen Mar 18, 2025
ce766f2
Update vision.py
danielhanchen Mar 18, 2025
beed394
Update vision.py
danielhanchen Mar 18, 2025
a09f3dc
Update vision.py
danielhanchen Mar 18, 2025
45377be
Update vision.py
danielhanchen Mar 18, 2025
558b052
Update rl_replacements.py
danielhanchen Mar 18, 2025
800a465
Update vision.py
danielhanchen Mar 18, 2025
8753a59
Update rl_replacements.py
danielhanchen Mar 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ triton = [
]

huggingface = [
"unsloth_zoo>=2025.3.11",
"unsloth_zoo>=2025.3.13",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
Expand Down Expand Up @@ -351,7 +351,7 @@ colab-ampere-torch220 = [
"flash-attn>=2.6.3",
]
colab-new = [
"unsloth_zoo>=2025.3.9",
"unsloth_zoo>=2025.3.13",
"packaging",
"tyro",
"transformers>=4.46.1,!=4.47.0",
Expand Down Expand Up @@ -511,4 +511,4 @@ cu126-ampere-torch260 = [
[project.urls]
homepage = "http://www.unsloth.ai"
documentation = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"
repository = "https://github.com/unslothai/unsloth"
4 changes: 2 additions & 2 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
# Check for unsloth_zoo
try:
unsloth_zoo_version = importlib_version("unsloth_zoo")
if Version(unsloth_zoo_version) < Version("2025.3.11"):
if Version(unsloth_zoo_version) < Version("2025.3.13"):
print(
"Unsloth: Updating Unsloth-Zoo utilies to the latest version.\n"\
"To disable this, set os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'"
"To disable this, set `os.environ['UNSLOTH_DISABLE_AUTO_UPDATES'] = '1'`"
)
if os.environ.get("UNSLOTH_DISABLE_AUTO_UPDATES", "0") == "0":
try:
Expand Down
39 changes: 22 additions & 17 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2025.3.14"
__version__ = "2025.3.15"

__all__ = [
"SUPPORTS_BFLOAT16",
Expand Down Expand Up @@ -182,6 +182,15 @@ def filter(self, x): return not (self.text in x.getMessage())
except:
pass

# Gemma3 It is strongly recommended to train Gemma3 models with the `eager`
try:
from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger
gemma3_logger.addFilter(HideLoggingMessage("strongly recommended"))
del gemma3_logger
except:
pass


# Patch get_model_param_count to record correct 4bit / 8bit
from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled
def get_model_param_count(model, trainable_only = False):
Expand Down Expand Up @@ -1016,13 +1025,7 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
"Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
)
pass

if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "0":
autocaster = contextlib.nullcontext()
else:
autocaster = torch.autocast(device_type = "cuda", dtype = torch.float32)
with autocaster:
outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
outputs = self._old_compute_loss(model, inputs, *args, **kwargs)
return outputs
pass

Expand Down Expand Up @@ -1126,7 +1129,9 @@ def patch_fast_lora():


def unsloth_compile_transformers(
dtype,
model_name,
model_types,
token = None,
revision = None,
trust_remote_code = False,
Expand Down Expand Up @@ -1164,15 +1169,12 @@ def unsloth_compile_transformers(
)
return
pass

model_types = get_transformers_model_type(
model_name = model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
model_types = ["siglip"] + model_types

if trust_remote_code:
print(
"Unsloth: We can't trace models if `trust_remote_code = True`, "\
"so turning off some optimizations!"
)
return
if disable: return

for model_type in model_types:
Expand Down Expand Up @@ -1204,6 +1206,9 @@ def unsloth_compile_transformers(
return_logits = return_logits,
)
pass
# Redo patches which override compiler
for temporary_patch in TEMPORARY_PATCHES:
temporary_patch()
return model_types
pass

Expand Down
9 changes: 6 additions & 3 deletions unsloth/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,7 +1548,7 @@ def unsloth_fast_generate(
if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs:
if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > self.config.max_position_embeddings:
raise ValueError(
f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\
f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {self.config.max_position_embeddings}!\n'\
'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.'
)
pass
Expand All @@ -1562,7 +1562,10 @@ def unsloth_fast_generate(
# For newer HF
kwargs["cache_implementation"] = "dynamic"
# For num_logits_to_keep
kwargs["num_logits_to_keep"] = 1
num_logits_to_keep = kwargs.get("num_logits_to_keep", None)
logits_to_keep = kwargs.get("logits_to_keep", None)
if num_logits_to_keep is None and logits_to_keep is None:
kwargs["num_logits_to_keep"] = 1

# Remove token_type_ids
kwargs.pop("token_type_ids", None)
Expand Down Expand Up @@ -1822,7 +1825,7 @@ def from_pretrained(

# Convert to HF format
_, quant_state_dict = get_vllm_state_dict(llm, config = model_config)
model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype)
model = convert_vllm_to_huggingface(quant_state_dict, model_config, dtype, bnb_config)
model.vllm_engine = llm
model.fast_generate = model.vllm_engine.generate
model.fast_generate_batches = functools.partial(generate_batches, model.vllm_engine)
Expand Down
68 changes: 58 additions & 10 deletions unsloth/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HAS_FLASH_ATTENTION,
HAS_FLASH_ATTENTION_SOFTCAPPING,
USE_MODELSCOPE,
get_transformers_model_type,
)
from .granite import FastGraniteModel
from .llama import FastLlamaModel, logger
Expand Down Expand Up @@ -66,6 +67,11 @@
unsloth_compile_transformers,
)

global FORCE_FLOAT32
FORCE_FLOAT32 = [
"gemma3",
]

class FastLanguageModel(FastLlamaModel):
@staticmethod
def from_pretrained(
Expand Down Expand Up @@ -212,7 +218,13 @@ def from_pretrained(
f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
raise RuntimeError(autoconfig_error or peft_error)
# Create a combined error message showing both failures
combined_error = (
"Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
f"AutoConfig error: {autoconfig_error}\n\n"
f"PeftConfig error: {peft_error}\n\n"
)
raise RuntimeError(combined_error)
pass

# Get base model for PEFT:
Expand Down Expand Up @@ -460,12 +472,17 @@ def from_pretrained(
*args, **kwargs,
):
if token is None: token = get_token()
assert (dtype is None or dtype == torch.float16 or dtype == torch.bfloat16)

SUPPORTS_BFLOAT16 = is_bfloat16_supported()
if dtype is None:
dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
logger.warning_once("Device does not support bfloat16. Will change to float16.")
dtype = torch.float16
assert(dtype in (torch.float16, torch.bfloat16, torch.float32))

patch_compiled_autograd()
patch_compiling_bitsandbytes()
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

if full_finetuning and (load_in_4bit or load_in_8bit):
print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
Expand All @@ -479,11 +496,6 @@ def from_pretrained(
"Also, we by default set `load_in_4bit = True`.\n"\
"If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`"
)
if load_in_4bit: pass
elif load_in_8bit: pass
elif not load_in_4bit and not load_in_8bit and not full_finetuning:
print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
load_in_4bit = True
pass

old_model_name = model_name
Expand Down Expand Up @@ -591,7 +603,13 @@ def from_pretrained(
f'Try `pip install --upgrade "transformers>=4.43.2"`\n'\
f"to obtain the latest transformers build, then restart this session."\
)
raise RuntimeError(autoconfig_error or peft_error)
# Create a combined error message showing both failures
combined_error = (
"Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
f"AutoConfig error: {autoconfig_error}\n\n"
f"PeftConfig error: {peft_error}\n\n"
)
raise RuntimeError(combined_error)
pass

# Get base model for PEFT:
Expand All @@ -616,10 +634,39 @@ def from_pretrained(
else:
redirector = contextlib.redirect_stdout(open(os.devnull, "w"))

# Get model types like Gemma3 etc
model_types = get_transformers_model_type(
model_name = model_name,
token = token,
revision = revision,
trust_remote_code = trust_remote_code,
)
model_types = ["siglip"] + model_types

# Set forced float32 env flag
os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
do_forced_float32 = False
model_type_arch = model_types[1]
global FORCE_FLOAT32
for disable_name in FORCE_FLOAT32:
if (disable_name.lower() == model_type_arch.lower() or \
disable_name.lower() in model_name.lower()) and \
((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
dtype = torch.bfloat16 # Change to bfloat16 loading
break
pass
# Patch gradient checkpointing
if use_gradient_checkpointing == "unsloth":
patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

with redirector:
patch_loss_functions(torch_compile = False)
model_types = unsloth_compile_transformers(
dtype = dtype,
model_name = model_name,
model_types = model_types,
token = token,
sdpa_dynamic_mask = True,
sdpa_bool_masks = True,
sdpa_gqa_replace = True,
Expand All @@ -644,6 +691,7 @@ def from_pretrained(
import_from_cache = False,
disable = False,
return_logits = return_logits,
trust_remote_code = trust_remote_code,
)
pass

Expand Down
1 change: 1 addition & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
"eval_accumulation_steps" : 2,
"torch_empty_cache_steps" : 250,
"logging_steps" : 1,
"max_seq_length" : None,
}
for k, v in replacements.items():
x = f"{k}( = [^,\n]{{1,}})?,\n"
Expand Down
12 changes: 7 additions & 5 deletions unsloth/models/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def grpo_trainer__prepare_inputs(function_name, function):

"with torch.inference_mode(), "\
"torch.amp.autocast(device_type = 'cuda', "\
"dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\
"if not torch.is_autocast_enabled('cuda') else nullcontext():",
"dtype = ((torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16) "\
"if not torch.is_autocast_enabled('cuda') else nullcontext())"\
"if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '0' else torch.float16):",
)

# Disable attaching a float32 conversion hook which upcasts logits to FP32
Expand Down Expand Up @@ -212,7 +213,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep)
# Otherwise, calculate normally:
if not hasattr(self, '_autocast_dtype'):
self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float32
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
Expand Down Expand Up @@ -254,11 +255,12 @@ def compute_loss(self, model, inputs, return_outputs = False, num_items_in_batch
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
bsz, qlen = input_ids.shape
# attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
attention_mask = None
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
# attention_mask = None
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
_input_ids = input_ids
_logits_to_keep = logits_to_keep

per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

# Compute the KL divergence between the model and the reference model
Expand Down
Loading