Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
729 commits
Select commit Hold shift + click to select a range
a7c257a
Update dataset_utils.py
danielhanchen Mar 13, 2025
6556f13
Update dataset_utils.py
danielhanchen Mar 13, 2025
9a05b2f
Update compiler.py
danielhanchen Mar 13, 2025
0f6fc7a
Update compiler.py
danielhanchen Mar 13, 2025
4bb152a
Update compiler.py
danielhanchen Mar 13, 2025
3ccdf86
Update compiler.py
danielhanchen Mar 13, 2025
1783ba1
Update loss_utils.py
danielhanchen Mar 13, 2025
b2983f4
Update loss_utils.py
danielhanchen Mar 13, 2025
3c9c731
Merge branch 'main' into nightly
danielhanchen Mar 13, 2025
4c5c77d
gpu_memory_utilization
danielhanchen Mar 14, 2025
b918327
Update temporary_patches.py
danielhanchen Mar 14, 2025
e8f561c
Update vision_utils.py
danielhanchen Mar 14, 2025
4459ef8
Update vision_utils.py
danielhanchen Mar 14, 2025
62e0e14
Update vision_utils.py
danielhanchen Mar 14, 2025
9f4b729
Update vision_utils.py
danielhanchen Mar 14, 2025
29a7abf
Update vision_utils.py
danielhanchen Mar 14, 2025
9830edd
Update vision_utils.py
danielhanchen Mar 14, 2025
be53fda
Update vision_utils.py
danielhanchen Mar 14, 2025
28f4df4
Update vision_utils.py
danielhanchen Mar 14, 2025
ad13d0a
train on completions VLMs
danielhanchen Mar 14, 2025
370cbd7
Update dataset_utils.py
danielhanchen Mar 14, 2025
bd60d26
Update dataset_utils.py
danielhanchen Mar 14, 2025
29ed559
Update dataset_utils.py
danielhanchen Mar 14, 2025
e0a4416
Update dataset_utils.py
danielhanchen Mar 14, 2025
d6e55ca
VLM train only on completions
danielhanchen Mar 14, 2025
adf8307
Update loss_utils.py
danielhanchen Mar 14, 2025
98d5885
Update dataset_utils.py
danielhanchen Mar 14, 2025
967c2ba
Update compiler.py
danielhanchen Mar 14, 2025
ddf2b8e
Update compiler.py
danielhanchen Mar 14, 2025
cb2f6c7
Update compiler.py
danielhanchen Mar 14, 2025
873c514
Update compiler.py
danielhanchen Mar 14, 2025
ca0b499
Update compiler.py
danielhanchen Mar 14, 2025
4908a16
Update compiler.py
danielhanchen Mar 14, 2025
1d4b5d7
Update compiler.py
danielhanchen Mar 14, 2025
81b45c6
Update saving_utils.py
danielhanchen Mar 14, 2025
261ffd2
Update llama_cpp.py
danielhanchen Mar 14, 2025
2ed281a
Update llama_cpp.py
danielhanchen Mar 14, 2025
d89a8fa
Update saving_utils.py
danielhanchen Mar 14, 2025
106736a
Update saving_utils.py
danielhanchen Mar 14, 2025
4abfdcd
Update __init__.py
danielhanchen Mar 14, 2025
0ac4464
Update compiler.py
danielhanchen Mar 14, 2025
e2fbe79
Update loss_utils.py
danielhanchen Mar 14, 2025
82665d4
Update compiler.py
danielhanchen Mar 14, 2025
9b6142e
Update loss_utils.py
danielhanchen Mar 14, 2025
ee92817
Update loss_utils.py
danielhanchen Mar 14, 2025
9b7600d
Update llama_cpp.py
danielhanchen Mar 14, 2025
d5b6d1c
Update loss_utils.py
danielhanchen Mar 14, 2025
86516ad
Update compiler.py
danielhanchen Mar 14, 2025
29553e4
Update llama_cpp.py
danielhanchen Mar 14, 2025
33e6c8e
Update compiler.py
danielhanchen Mar 14, 2025
5202605
Update vllm_utils.py
danielhanchen Mar 14, 2025
ca52896
Update rl_replacements.py
danielhanchen Mar 14, 2025
7baa442
Update rl_replacements.py
danielhanchen Mar 14, 2025
7ff5a1a
Update rl_replacements.py
danielhanchen Mar 14, 2025
e80aa10
Update rl_replacements.py
danielhanchen Mar 14, 2025
e93d93f
Update rl_replacements.py
danielhanchen Mar 14, 2025
9a6c231
Update rl_replacements.py
danielhanchen Mar 14, 2025
c8abd45
Update training_utils.py
danielhanchen Mar 14, 2025
1633c78
Merge branch 'main' into nightly
danielhanchen Mar 15, 2025
964129b
Update dataset_utils.py
danielhanchen Mar 15, 2025
3b690ad
Update dataset_utils.py
danielhanchen Mar 16, 2025
7bb4a13
Revert "Update dataset_utils.py"
danielhanchen Mar 16, 2025
947c5e9
Update temporary_patches.py
danielhanchen Mar 16, 2025
2fe9c6c
Update temporary_patches.py
danielhanchen Mar 16, 2025
0b2dc97
Update temporary_patches.py
danielhanchen Mar 16, 2025
b9a96dc
Update temporary_patches.py
danielhanchen Mar 16, 2025
0784a07
Update temporary_patches.py
danielhanchen Mar 16, 2025
80c2dc8
Update temporary_patches.py
danielhanchen Mar 16, 2025
26c817d
Update compiler.py
danielhanchen Mar 16, 2025
d3cdd17
Update compiler.py
danielhanchen Mar 16, 2025
31e778a
Remove prints
danielhanchen Mar 16, 2025
2c6a3c5
Update compiler.py
danielhanchen Mar 16, 2025
f3f3c9c
Update saving_utils.py
danielhanchen Mar 16, 2025
93b6a88
Update temporary_patches.py
danielhanchen Mar 16, 2025
86aee5c
Update __init__.py
danielhanchen Mar 16, 2025
ac38bff
Update pyproject.toml
danielhanchen Mar 16, 2025
f64e153
Update vllm_utils.py
danielhanchen Mar 16, 2025
4c72e79
bug fix #2008 unsloth issue - load_in_4bit = True + fast_inference = …
void-mckenzie Mar 16, 2025
1974798
Update dataset_utils.py
danielhanchen Mar 16, 2025
4df4417
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 16, 2025
a5c20e1
Update compiler.py
danielhanchen Mar 17, 2025
a434d45
Update temporary_patches.py
danielhanchen Mar 17, 2025
3cfb98f
Gemma 3 fixes
danielhanchen Mar 17, 2025
fc5f1c0
Update temporary_patches.py
danielhanchen Mar 17, 2025
b317e90
Update compiler.py
danielhanchen Mar 17, 2025
4121dd0
Update compiler.py
danielhanchen Mar 17, 2025
c59dcde
Gemma 3 fixes
danielhanchen Mar 17, 2025
d98ae2e
Update patching_utils.py
danielhanchen Mar 17, 2025
3073ea3
Update compiler.py
danielhanchen Mar 17, 2025
57ff5f6
Update compiler.py
danielhanchen Mar 17, 2025
c7e803b
Update patching_utils.py
danielhanchen Mar 17, 2025
3daaf0d
Update temporary_patches.py
danielhanchen Mar 17, 2025
b619b58
Update compiler.py
danielhanchen Mar 17, 2025
4e78082
Update compiler.py
danielhanchen Mar 17, 2025
c8ba677
Update temporary_patches.py
danielhanchen Mar 17, 2025
fb68ecc
Update temporary_patches.py
danielhanchen Mar 17, 2025
e5a73fe
Update temporary_patches.py
danielhanchen Mar 17, 2025
d7bbe30
Update temporary_patches.py
danielhanchen Mar 17, 2025
5f99275
Update temporary_patches.py
danielhanchen Mar 17, 2025
346812f
Update temporary_patches.py
danielhanchen Mar 17, 2025
b907d0c
Update temporary_patches.py
danielhanchen Mar 17, 2025
789171c
Update temporary_patches.py
danielhanchen Mar 17, 2025
4e2c94a
Update compiler.py
danielhanchen Mar 17, 2025
4740c99
Update compiler.py
danielhanchen Mar 17, 2025
4658d94
Update compiler.py
danielhanchen Mar 17, 2025
f9de6e9
Update compiler.py
danielhanchen Mar 17, 2025
dbdbc63
Update compiler.py
danielhanchen Mar 17, 2025
55b1963
Update compiler.py
danielhanchen Mar 17, 2025
e997ee1
Update compiler.py
danielhanchen Mar 17, 2025
0ba033f
Update compiler.py
danielhanchen Mar 17, 2025
bf821ba
Update compiler.py
danielhanchen Mar 17, 2025
d8c6e59
Update compiler.py
danielhanchen Mar 17, 2025
9967ce3
Update compiler.py
danielhanchen Mar 17, 2025
7b0c535
Update compiler.py
danielhanchen Mar 17, 2025
e6859ce
Update compiler.py
danielhanchen Mar 17, 2025
b2a8f47
Update compiler.py
danielhanchen Mar 17, 2025
ca79c93
Update compiler.py
danielhanchen Mar 17, 2025
3f67ed6
Update compiler.py
danielhanchen Mar 17, 2025
e5fb044
Update compiler.py
danielhanchen Mar 18, 2025
4a1bf2f
Update compiler.py
danielhanchen Mar 18, 2025
36ec4ee
Update compiler.py
danielhanchen Mar 18, 2025
7d1dc81
compiler
danielhanchen Mar 18, 2025
16d6137
Update gradient_checkpointing.py
danielhanchen Mar 18, 2025
9b78566
Update temporary_patches.py
danielhanchen Mar 18, 2025
e0edefe
Update temporary_patches.py
danielhanchen Mar 18, 2025
719e379
Update temporary_patches.py
danielhanchen Mar 18, 2025
8beb2b7
Update temporary_patches.py
danielhanchen Mar 18, 2025
f9cf701
Update temporary_patches.py
danielhanchen Mar 18, 2025
aa8848c
Update temporary_patches.py
danielhanchen Mar 18, 2025
ee940a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
d086158
Update temporary_patches.py
danielhanchen Mar 18, 2025
5a43de2
Update temporary_patches.py
danielhanchen Mar 18, 2025
1f6589b
Update temporary_patches.py
danielhanchen Mar 18, 2025
9b904a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
3c0504b
Update temporary_patches.py
danielhanchen Mar 18, 2025
417161e
Update temporary_patches.py
danielhanchen Mar 18, 2025
3f024b6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b0bd2f4
Update temporary_patches.py
danielhanchen Mar 18, 2025
640e071
Update temporary_patches.py
danielhanchen Mar 18, 2025
05c2232
Update temporary_patches.py
danielhanchen Mar 18, 2025
593eecb
Update temporary_patches.py
danielhanchen Mar 18, 2025
e9c935f
Update temporary_patches.py
danielhanchen Mar 18, 2025
b71160c
causal mask dtype
danielhanchen Mar 18, 2025
a6fedb6
Fix checkpoint and save from local file (#74)
Erland366 Mar 18, 2025
c566b02
Update patching_utils.py
danielhanchen Mar 18, 2025
aaf5feb
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 18, 2025
94f5f4f
Update patching_utils.py
danielhanchen Mar 18, 2025
26c67cf
Update temporary_patches.py
danielhanchen Mar 18, 2025
d92bab6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b04cf4b
Update compiler.py
danielhanchen Mar 18, 2025
4565db3
Update loss_utils.py
danielhanchen Mar 18, 2025
e368810
Update compiler.py
danielhanchen Mar 18, 2025
ce07e0f
Update vllm_utils.py
danielhanchen Mar 18, 2025
114150d
Update compiler.py
danielhanchen Mar 18, 2025
6bd69f1
Update peft_utils.py
danielhanchen Mar 18, 2025
9cee216
Update rl_replacements.py
danielhanchen Mar 18, 2025
df8ac03
Update vllm_utils.py
danielhanchen Mar 18, 2025
e5a321f
Update temporary_patches.py
danielhanchen Mar 18, 2025
134857d
Update temporary_patches.py
danielhanchen Mar 18, 2025
dec6433
Update temporary_patches.py
danielhanchen Mar 18, 2025
b14149b
Update temporary_patches.py
danielhanchen Mar 18, 2025
07f7dde
Update temporary_patches.py
danielhanchen Mar 18, 2025
7600d35
Update temporary_patches.py
danielhanchen Mar 18, 2025
679edeb
Update temporary_patches.py
danielhanchen Mar 18, 2025
5fd25ec
Update temporary_patches.py
danielhanchen Mar 18, 2025
a884b3c
Update temporary_patches.py
danielhanchen Mar 18, 2025
b6ab8bd
Update temporary_patches.py
danielhanchen Mar 18, 2025
cc3ca48
Update temporary_patches.py
danielhanchen Mar 18, 2025
9f5b67d
Update temporary_patches.py
danielhanchen Mar 18, 2025
e4980b2
Update temporary_patches.py
danielhanchen Mar 18, 2025
d745fb7
Update temporary_patches.py
danielhanchen Mar 18, 2025
201c1ab
Merge branch 'main' into nightly
danielhanchen Mar 18, 2025
2fb83f0
Update compiler.py
danielhanchen Mar 18, 2025
3551715
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
ab47b77
Update utils.py
danielhanchen Mar 19, 2025
ceed6ab
Update temporary_patches.py
danielhanchen Mar 19, 2025
b5611c2
Update temporary_patches.py
danielhanchen Mar 19, 2025
480aaf7
Update temporary_patches.py
danielhanchen Mar 19, 2025
637c7ad
Update temporary_patches.py
danielhanchen Mar 19, 2025
5a224bb
Update temporary_patches.py
danielhanchen Mar 19, 2025
2248156
Update temporary_patches.py
danielhanchen Mar 19, 2025
ee6ed2b
Update temporary_patches.py
danielhanchen Mar 19, 2025
6d10b9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d431b0
Update temporary_patches.py
danielhanchen Mar 19, 2025
42491ca
Update temporary_patches.py
danielhanchen Mar 19, 2025
d3ddadf
Update vllm_utils.py
danielhanchen Mar 19, 2025
dbc6a43
Update vllm_utils.py
danielhanchen Mar 19, 2025
0c4b0d2
Update vllm_utils.py
danielhanchen Mar 19, 2025
5504033
Update vllm_utils.py
danielhanchen Mar 19, 2025
2a84e79
Update dataset_utils.py
danielhanchen Mar 19, 2025
cbbc4a3
bidirectional attention
danielhanchen Mar 19, 2025
3bf532d
Update vllm_utils.py
danielhanchen Mar 19, 2025
8e687b5
Update __init__.py
danielhanchen Mar 19, 2025
a723520
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d1dd42
Update temporary_patches.py
danielhanchen Mar 19, 2025
aec2701
Update temporary_patches.py
danielhanchen Mar 19, 2025
23a3a59
Update vllm_utils.py
danielhanchen Mar 19, 2025
2874477
Update vllm_utils.py
danielhanchen Mar 19, 2025
7d40491
Update vllm_utils.py
danielhanchen Mar 19, 2025
2275642
Update vllm_utils.py
danielhanchen Mar 19, 2025
9cd348f
Update vllm_utils.py
danielhanchen Mar 19, 2025
6e33fa9
Update vllm_utils.py
danielhanchen Mar 19, 2025
7ad0f55
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
7fd23a0
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
9176758
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
b5a38b0
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
446787d
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
d2bdd9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
83bde7d
Update temporary_patches.py
danielhanchen Mar 19, 2025
0fe9eaa
Update temporary_patches.py
danielhanchen Mar 19, 2025
3d70a80
Update temporary_patches.py
danielhanchen Mar 19, 2025
6b6587d
Merge branch 'main' into nightly
danielhanchen Mar 21, 2025
88301c5
Update loss_utils.py
danielhanchen Mar 21, 2025
debc0e8
Update loss_utils.py
danielhanchen Mar 21, 2025
7dc2e9d
Update loss_utils.py
danielhanchen Mar 21, 2025
57b4973
Update loss_utils.py
danielhanchen Mar 21, 2025
3cfa271
Update loss_utils.py
danielhanchen Mar 21, 2025
1f5b6f2
Update __init__.py
danielhanchen Mar 21, 2025
2f3c87b
fix: AsyncLLMEngine bugs (#82)
bradhilton Mar 22, 2025
64dd76c
fixed a typo in L119, removing unnecessary len() (#84)
SpaceHunterInf Mar 22, 2025
5a1a2b5
Merge branch 'main' into nightly
danielhanchen Mar 22, 2025
a62e4c6
Fix gradient checkpointing warning filter implementation
rolandtannous Mar 24, 2025
d115cea
Input grads fix for gemma3 (#96)
mmathew23 Mar 25, 2025
454757c
Merge pull request #97 from rolandtannous/fix/suppress-gradient-check…
shimmyshimmer Mar 25, 2025
c50123a
Update vision_utils.py
danielhanchen Mar 26, 2025
b199491
Vision requires grad
danielhanchen Mar 26, 2025
1670fa6
Check SDPA for Mistral / Pixtral
danielhanchen Mar 26, 2025
e32f797
Update compiler.py
danielhanchen Mar 26, 2025
b9d9cc5
Update vision_utils.py
danielhanchen Mar 26, 2025
5e3c88f
Update vision_utils.py
danielhanchen Mar 26, 2025
5c4086c
Update vision_utils.py
danielhanchen Mar 26, 2025
0599242
Update __init__.py
danielhanchen Mar 26, 2025
8da5939
Update vision_utils.py
danielhanchen Mar 26, 2025
db90dca
Update vision_utils.py
danielhanchen Mar 26, 2025
51cefe5
Update vision_utils.py
danielhanchen Mar 26, 2025
20b42ce
Update vision_utils.py
danielhanchen Mar 26, 2025
b03ded6
Update vision_utils.py
danielhanchen Mar 26, 2025
65469c2
Update vision_utils.py
danielhanchen Mar 26, 2025
7f4eb00
Update vision_utils.py
danielhanchen Mar 26, 2025
8584b5d
Update vision_utils.py
danielhanchen Mar 26, 2025
8bb6b55
Update vision_utils.py
danielhanchen Mar 26, 2025
20a61b0
Update vision_utils.py
danielhanchen Mar 26, 2025
86ca6d5
Update vision_utils.py
danielhanchen Mar 26, 2025
0221094
Update vision_utils.py
danielhanchen Mar 26, 2025
d13ebf7
Update vision_utils.py
danielhanchen Mar 26, 2025
f6c4b2e
Update vision_utils.py
danielhanchen Mar 26, 2025
2d1e506
Update vllm_utils.py (#99)
5k5000 Mar 26, 2025
23e018f
Update vision_utils.py
danielhanchen Mar 26, 2025
9f1eaa2
Fixes to support IterableDataset (#98)
marcandrelarochelle Mar 26, 2025
affb9d8
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__version__ = "2025.3.16"
__version__ = "2025.3.17"

from importlib.util import find_spec
if find_spec("unsloth") is None:
Expand Down
40 changes: 32 additions & 8 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def apply_fused_lm_head(forward):

cross_entropy_replacement = cross_entropy_replacement\
.replace(
"$KWARGS$",
"$KWARGS$",
"locals().get('loss_kwargs', {}) or locals().get('kwargs', {})"
)

Expand Down Expand Up @@ -1179,7 +1179,7 @@ def patch_gradient_checkpointing(module, source):
.replace("LAYER", layer).replace("MODULELIST_ITEM", modulelist_item)\
.replace("ARGS", args).replace("$", spaces)
forward = forward.replace(forward[span[0] : span[1]], replacer)

# Also fix init
spaces = init.find("def")
init = init + "\n" + (spaces + 4) * " " + "self.gradient_checkpointing = False\n\n"
Expand Down Expand Up @@ -1381,10 +1381,10 @@ def patch_gradient_accumulation(modeling_file, module):

functions = dir(modeling_file)
module = eval(f"modeling_file.{module}")
try:
try:
forward = module.forward
source = inspect.getsource(forward)
except:
except:
return None
has_kwargs = tuple(inspect.signature(forward).parameters.values())[-1].kind == inspect._VAR_KEYWORD
if has_kwargs: return None
Expand Down Expand Up @@ -1449,7 +1449,12 @@ def unsloth_compile_transformers(
import_from_cache : bool = False,
disable : bool = False,
return_logits : bool = False,
supports_sdpa : list = None,
):
# import transformers logging module and instantiate model_type logging instance.
from transformers import logging as transformers_logging
model_logger = transformers_logging.get_logger(f"modeling_{model_type}")

# All Unsloth Zoo code licensed under LGPLv3
disable = disable or (os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1")
if fast_residual_stream:
Expand All @@ -1461,8 +1466,8 @@ def unsloth_compile_transformers(
modeling_file = eval(model_location)
if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return

# Remove `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
exec("modeling_file.logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())
# Use the transformers model_type logger to suppress the message: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())

# torch_compile_options
UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
Expand All @@ -1489,7 +1494,7 @@ def unsloth_compile_transformers(
if "UNSLOTH_FULLGRAPH" not in os.environ:
os.environ["UNSLOTH_FULLGRAPH"] = UNSLOTH_FULLGRAPH
else:
UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"] == "1"
UNSLOTH_FULLGRAPH = os.environ["UNSLOTH_FULLGRAPH"]
pass
UNSLOTH_FULLGRAPH = UNSLOTH_FULLGRAPH == "1"

Expand Down Expand Up @@ -1547,6 +1552,17 @@ def unsloth_compile_transformers(
)
torch_modules = [x for x in torch_modules if x not in removal]

# Check SDPA support to decide whether to load as eager or SDPA (e.g. Pixtral / Mistral 3 don't have SDPA)
if supports_sdpa is not None:
assert(type(supports_sdpa) is list and len(supports_sdpa) == 1)
if len(scaled_dot_product_attention_modules) != 0:
if supports_sdpa[0] != False: supports_sdpa[0] = True
elif "_supports_sdpa = True" in full_source:
if supports_sdpa[0] != False: supports_sdpa[0] = True
else:
supports_sdpa[0] = False
pass

# Get functions which are called
called_functions = []
for function in functions:
Expand All @@ -1566,6 +1582,14 @@ def unsloth_compile_transformers(
except: continue
fullgraph = not ("nn.Linear" in source or "nn.ModuleList" in source)

# Eg SiglipVisionEmbeddings and CLIPVisionEmbeddings
if str(module).endswith("VisionEmbeddings"):
# Sometimes we attach a post-forward call to make sure requires_grad is set.
# This breaks full-graph mode and fails, so we relax the full-graph check instead.
# We attach via a post-forward call because the forward call in transformers only
# passes keyword arguments, and the pre_forward hook doesn't pass kwargs.
fullgraph = False

# Check if other modules is used as well
for another_module in torch_modules:
if another_module in source:
Expand Down Expand Up @@ -1792,7 +1816,7 @@ def unsloth_compile_transformers(
# Disable if torch < 2.5 or V100s 7.0 (Tesla T4 7.5 works) or old Triton < 3
if OLD_CUDA_ARCH_VERSION or OLD_TORCH_VERSION or OLD_TRITON_VERSION:
continue

module_class = eval(f"modeling_file.{module}")
if hasattr(module_class, "forward") and issubclass(module_class, GenerationMixin):
try:
Expand Down
25 changes: 19 additions & 6 deletions unsloth_zoo/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,10 @@ def _train_on_responses_only(examples):
if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
if not hasattr(trainer.train_dataset, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
if isinstance(trainer.train_dataset, IterableDataset):
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batch_size = trainer.train_dataset._ex_iterable.batch_size, batched = True)
else:
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
pass

if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
Expand All @@ -343,11 +346,17 @@ def _train_on_responses_only(examples):
for key, value in trainer.eval_dataset.items():
if not hasattr(value, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc)
if isinstance(trainer.eval_dataset, IterableDataset):
trainer.eval_dataset[key] = value.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True)
else:
trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc)
else:
if not hasattr(trainer.eval_dataset, "map"):
raise TypeError("Unsloth: train_on_responses_only does not work on lists!")
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
if isinstance(trainer.eval_dataset, IterableDataset):
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batch_size = trainer.eval_dataset._ex_iterable.batch_size, batched = True)
else:
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
pass
pass

Expand Down Expand Up @@ -531,14 +540,14 @@ def sft_prepare_dataset(
if do_tokenize:
# Check double BOS tokens
if do_formatting_func:
test_text = formatting_func(dataset[0])
test_text = formatting_func(next(iter(dataset)))
if not isinstance(test_text, list):
raise ValueError(
"Unsloth: The `formatting_func` should return a list of processed strings."
)
test_text = test_text[0]
else:
test_text = dataset[0][dataset_text_field]
test_text = next(iter(dataset))[dataset_text_field][0]

# Get chat template
chat_template = getattr(processing_class, 'chat_template', '')
Expand Down Expand Up @@ -570,7 +579,11 @@ def _tokenize(example):
)
pass

map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
if not isinstance(dataset, IterableDataset):
map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
else:
map_kwargs["batch_size"] = dataset._ex_iterable.batch_size

if use_desc: map_kwargs["desc"] = f'Unsloth: Tokenizing ["{dataset_text_field}"]'
dataset = dataset.map(_tokenize, batched = True, **map_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion unsloth_zoo/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def requires_grad_pre_hook(module, input):
module_name = "model." + ".".join(name_components[:final_where])
module = eval(module_name)

if hasattr(module, "config") and module.config.__class__.__name__ == "CLIPVisionConfig":
if hasattr(module, "config") and (module.config.__class__.__name__ in ("CLIPVisionConfig", "SiglipVisionConfig",)):
# CLIP - backtrack to get_input_embeddings since requires_grad fails!
old_module = model
for module_name, module in model.named_modules():
Expand Down
48 changes: 43 additions & 5 deletions unsloth_zoo/vision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ class UnslothVisionDataCollator:
"padding_token_ids", "dtype", "ignore_index", \
"processor", "formatting_func", "image_size", \
"max_seq_length", "truncation", "train_on_responses_only", \
"num_proc",
"num_proc", "assistant_single_content",

def __init__(
self,
model,
processor,
max_seq_length = None,
max_seq_length = None,
formatting_func = None,
resize = "min", # Can be (10, 10) or "min" to resize to fit
# the model's default image_size or "max"
Expand Down Expand Up @@ -335,6 +335,36 @@ def __init__(
)
else:
self.train_on_responses_only = None

# Check which content type the VLM tokenizer allows for assistant messages!
# Good for Mistral V3 and Pixtral, I think
try:
processor.apply_chat_template([
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": "Hello!"}]},
{"role": "assistant", "content": [
{"type": "text", "text": "How can I help you?"}]}
])
self.assistant_single_content = False
except TypeError:
try:
processor.apply_chat_template([
{"role": "user", "content": [
{"type": "image"},
{"type": "text", "text": "Hello!"}]},
{"role": "assistant", "content": "How can I help you?"}
])
self.assistant_single_content = True
print(
f"Unsloth: {processor.__class__.__name__} only accepts 1 "\
"text field for assistant roles!\n"\
"We will auto fix the data collator to support it!"
)
except Exception as e:
raise RuntimeError(e)
except Exception as e:
raise RuntimeError(e)
return
pass

Expand Down Expand Up @@ -366,7 +396,7 @@ def __call__(self, examples):
)
content = message["content"]
if type(content) is str:
message["content"] = [{"type" : "text", "text" : content}]
message["content"] = content = [{"type" : "text", "text" : content}]
elif type(content) is list or type(content) is tuple:
part = content[0]
assert("type" in part)
Expand All @@ -377,6 +407,15 @@ def __call__(self, examples):
"[{'role':'user', 'content':[{'type':'text', 'text':'Hello!'}]}]"
)
pass

# Also fix the messages if assistant must only be 1 string!
# Only affects Mistral V3 I think!
if self.assistant_single_content:
for message in messages:
if message["role"] == "assistant":
if type(content := message["content"]) is list:
message["content"] = content[0]["text"]
pass
pass
message = self.processor.apply_chat_template(
messages,
Expand Down Expand Up @@ -417,7 +456,7 @@ def __call__(self, examples):
return_tensors = "pt",
add_special_tokens = False, # Stop double BOS
)
# Cannot remove due to bidirectional attention fro Gemma 3!
# Cannot remove due to bidirectional attention from Gemma 3!
# batch.pop("token_type_ids", None)

# Pixtral accepts multiple images, so we have to cast it individually
Expand All @@ -439,7 +478,6 @@ def __call__(self, examples):
labels = batch["input_ids"].clone()
labels[torch.isin(labels, self.padding_token_ids)] = self.ignore_index
batch["labels"] = labels

if self.train_on_responses_only:
batch["labels"] = self.train_on_responses_only(batch)["labels"]
return batch
Expand Down
6 changes: 3 additions & 3 deletions unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,12 +1346,12 @@ def generate_batches(llm, inputs, n_batches = None, lora_request = None, *args,

batches = create_batches(inputs, n_batches)
kwargs["lora_request"] = lora_request
outputs = []
output_list = []
for batch in batches:
outputs = llm.generate(batch, *args, **kwargs)
outputs += list(outputs)
output_list += list(outputs)
pass
return outputs
return output_list
pass


Expand Down