Skip to content
Merged

TTS #141

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
854 commits
Select commit Hold shift + click to select a range
719e379
Update temporary_patches.py
danielhanchen Mar 18, 2025
8beb2b7
Update temporary_patches.py
danielhanchen Mar 18, 2025
f9cf701
Update temporary_patches.py
danielhanchen Mar 18, 2025
aa8848c
Update temporary_patches.py
danielhanchen Mar 18, 2025
ee940a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
d086158
Update temporary_patches.py
danielhanchen Mar 18, 2025
5a43de2
Update temporary_patches.py
danielhanchen Mar 18, 2025
1f6589b
Update temporary_patches.py
danielhanchen Mar 18, 2025
9b904a9
Update temporary_patches.py
danielhanchen Mar 18, 2025
3c0504b
Update temporary_patches.py
danielhanchen Mar 18, 2025
417161e
Update temporary_patches.py
danielhanchen Mar 18, 2025
3f024b6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b0bd2f4
Update temporary_patches.py
danielhanchen Mar 18, 2025
640e071
Update temporary_patches.py
danielhanchen Mar 18, 2025
05c2232
Update temporary_patches.py
danielhanchen Mar 18, 2025
593eecb
Update temporary_patches.py
danielhanchen Mar 18, 2025
e9c935f
Update temporary_patches.py
danielhanchen Mar 18, 2025
b71160c
causal mask dtype
danielhanchen Mar 18, 2025
a6fedb6
Fix checkpoint and save from local file (#74)
Erland366 Mar 18, 2025
c566b02
Update patching_utils.py
danielhanchen Mar 18, 2025
aaf5feb
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 18, 2025
94f5f4f
Update patching_utils.py
danielhanchen Mar 18, 2025
26c67cf
Update temporary_patches.py
danielhanchen Mar 18, 2025
d92bab6
Update temporary_patches.py
danielhanchen Mar 18, 2025
b04cf4b
Update compiler.py
danielhanchen Mar 18, 2025
4565db3
Update loss_utils.py
danielhanchen Mar 18, 2025
e368810
Update compiler.py
danielhanchen Mar 18, 2025
ce07e0f
Update vllm_utils.py
danielhanchen Mar 18, 2025
114150d
Update compiler.py
danielhanchen Mar 18, 2025
6bd69f1
Update peft_utils.py
danielhanchen Mar 18, 2025
9cee216
Update rl_replacements.py
danielhanchen Mar 18, 2025
df8ac03
Update vllm_utils.py
danielhanchen Mar 18, 2025
e5a321f
Update temporary_patches.py
danielhanchen Mar 18, 2025
134857d
Update temporary_patches.py
danielhanchen Mar 18, 2025
dec6433
Update temporary_patches.py
danielhanchen Mar 18, 2025
b14149b
Update temporary_patches.py
danielhanchen Mar 18, 2025
07f7dde
Update temporary_patches.py
danielhanchen Mar 18, 2025
7600d35
Update temporary_patches.py
danielhanchen Mar 18, 2025
679edeb
Update temporary_patches.py
danielhanchen Mar 18, 2025
5fd25ec
Update temporary_patches.py
danielhanchen Mar 18, 2025
a884b3c
Update temporary_patches.py
danielhanchen Mar 18, 2025
b6ab8bd
Update temporary_patches.py
danielhanchen Mar 18, 2025
cc3ca48
Update temporary_patches.py
danielhanchen Mar 18, 2025
9f5b67d
Update temporary_patches.py
danielhanchen Mar 18, 2025
e4980b2
Update temporary_patches.py
danielhanchen Mar 18, 2025
d745fb7
Update temporary_patches.py
danielhanchen Mar 18, 2025
201c1ab
Merge branch 'main' into nightly
danielhanchen Mar 18, 2025
2fb83f0
Update compiler.py
danielhanchen Mar 18, 2025
3551715
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
ab47b77
Update utils.py
danielhanchen Mar 19, 2025
ceed6ab
Update temporary_patches.py
danielhanchen Mar 19, 2025
b5611c2
Update temporary_patches.py
danielhanchen Mar 19, 2025
480aaf7
Update temporary_patches.py
danielhanchen Mar 19, 2025
637c7ad
Update temporary_patches.py
danielhanchen Mar 19, 2025
5a224bb
Update temporary_patches.py
danielhanchen Mar 19, 2025
2248156
Update temporary_patches.py
danielhanchen Mar 19, 2025
ee6ed2b
Update temporary_patches.py
danielhanchen Mar 19, 2025
6d10b9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d431b0
Update temporary_patches.py
danielhanchen Mar 19, 2025
42491ca
Update temporary_patches.py
danielhanchen Mar 19, 2025
d3ddadf
Update vllm_utils.py
danielhanchen Mar 19, 2025
dbc6a43
Update vllm_utils.py
danielhanchen Mar 19, 2025
0c4b0d2
Update vllm_utils.py
danielhanchen Mar 19, 2025
5504033
Update vllm_utils.py
danielhanchen Mar 19, 2025
2a84e79
Update dataset_utils.py
danielhanchen Mar 19, 2025
cbbc4a3
bidirectional attention
danielhanchen Mar 19, 2025
3bf532d
Update vllm_utils.py
danielhanchen Mar 19, 2025
8e687b5
Update __init__.py
danielhanchen Mar 19, 2025
a723520
Update temporary_patches.py
danielhanchen Mar 19, 2025
9d1dd42
Update temporary_patches.py
danielhanchen Mar 19, 2025
aec2701
Update temporary_patches.py
danielhanchen Mar 19, 2025
23a3a59
Update vllm_utils.py
danielhanchen Mar 19, 2025
2874477
Update vllm_utils.py
danielhanchen Mar 19, 2025
7d40491
Update vllm_utils.py
danielhanchen Mar 19, 2025
2275642
Update vllm_utils.py
danielhanchen Mar 19, 2025
9cd348f
Update vllm_utils.py
danielhanchen Mar 19, 2025
6e33fa9
Update vllm_utils.py
danielhanchen Mar 19, 2025
7ad0f55
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
7fd23a0
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
9176758
Update vllm_lora_worker_manager.py
danielhanchen Mar 19, 2025
b5a38b0
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
446787d
Merge branch 'main' into nightly
danielhanchen Mar 19, 2025
d2bdd9b
Update temporary_patches.py
danielhanchen Mar 19, 2025
83bde7d
Update temporary_patches.py
danielhanchen Mar 19, 2025
0fe9eaa
Update temporary_patches.py
danielhanchen Mar 19, 2025
3d70a80
Update temporary_patches.py
danielhanchen Mar 19, 2025
6b6587d
Merge branch 'main' into nightly
danielhanchen Mar 21, 2025
88301c5
Update loss_utils.py
danielhanchen Mar 21, 2025
debc0e8
Update loss_utils.py
danielhanchen Mar 21, 2025
7dc2e9d
Update loss_utils.py
danielhanchen Mar 21, 2025
57b4973
Update loss_utils.py
danielhanchen Mar 21, 2025
3cfa271
Update loss_utils.py
danielhanchen Mar 21, 2025
1f5b6f2
Update __init__.py
danielhanchen Mar 21, 2025
2f3c87b
fix: AsyncLLMEngine bugs (#82)
bradhilton Mar 22, 2025
64dd76c
fixed a typo in L119, removing unnecessary len() (#84)
SpaceHunterInf Mar 22, 2025
5a1a2b5
Merge branch 'main' into nightly
danielhanchen Mar 22, 2025
a62e4c6
Fix gradient checkpointing warning filter implementation
rolandtannous Mar 24, 2025
d115cea
Input grads fix for gemma3 (#96)
mmathew23 Mar 25, 2025
454757c
Merge pull request #97 from rolandtannous/fix/suppress-gradient-check…
shimmyshimmer Mar 25, 2025
c50123a
Update vision_utils.py
danielhanchen Mar 26, 2025
b199491
Vision requires grad
danielhanchen Mar 26, 2025
1670fa6
Check SDPA for Mistral / Pixtral
danielhanchen Mar 26, 2025
e32f797
Update compiler.py
danielhanchen Mar 26, 2025
b9d9cc5
Update vision_utils.py
danielhanchen Mar 26, 2025
5e3c88f
Update vision_utils.py
danielhanchen Mar 26, 2025
5c4086c
Update vision_utils.py
danielhanchen Mar 26, 2025
0599242
Update __init__.py
danielhanchen Mar 26, 2025
8da5939
Update vision_utils.py
danielhanchen Mar 26, 2025
db90dca
Update vision_utils.py
danielhanchen Mar 26, 2025
51cefe5
Update vision_utils.py
danielhanchen Mar 26, 2025
20b42ce
Update vision_utils.py
danielhanchen Mar 26, 2025
b03ded6
Update vision_utils.py
danielhanchen Mar 26, 2025
65469c2
Update vision_utils.py
danielhanchen Mar 26, 2025
7f4eb00
Update vision_utils.py
danielhanchen Mar 26, 2025
8584b5d
Update vision_utils.py
danielhanchen Mar 26, 2025
8bb6b55
Update vision_utils.py
danielhanchen Mar 26, 2025
20a61b0
Update vision_utils.py
danielhanchen Mar 26, 2025
86ca6d5
Update vision_utils.py
danielhanchen Mar 26, 2025
0221094
Update vision_utils.py
danielhanchen Mar 26, 2025
d13ebf7
Update vision_utils.py
danielhanchen Mar 26, 2025
f6c4b2e
Update vision_utils.py
danielhanchen Mar 26, 2025
2d1e506
Update vllm_utils.py (#99)
5k5000 Mar 26, 2025
23e018f
Update vision_utils.py
danielhanchen Mar 26, 2025
9f1eaa2
Fixes to support IterableDataset (#98)
marcandrelarochelle Mar 26, 2025
affb9d8
Merge branch 'nightly' of https://github.com/unslothai/unsloth-zoo in…
danielhanchen Mar 26, 2025
1efc541
Merge branch 'main' into nightly
danielhanchen Mar 26, 2025
6ae6d0e
Merge branch 'main' into nightly
danielhanchen May 6, 2025
8986b95
Update vllm_utils.py
danielhanchen May 6, 2025
d37ed39
Create vllm_rlhf_utils.py
danielhanchen May 6, 2025
abf388b
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
6ff1836
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
deea45f
Update vllm_rlhf_utils.py
danielhanchen May 6, 2025
4ae4dd6
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
8dcff39
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
6edc24c
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
e61ad9b
Update vllm_rlhf_utils.py
danielhanchen May 7, 2025
f74ba8e
Update vllm_rlhf_utils.py
danielhanchen May 8, 2025
ee5bb55
Update vllm_rlhf_utils.py
danielhanchen May 8, 2025
994533c
vLLM for Qwen 3
danielhanchen May 11, 2025
23d8aa7
Merge branch 'main' into nightly
danielhanchen May 11, 2025
569c3ae
Update vllm_utils.py
danielhanchen May 11, 2025
847ad07
Update vllm_utils.py
danielhanchen May 11, 2025
121ff4b
Merge branch 'main' into nightly
danielhanchen May 11, 2025
58d4084
Update vllm_utils.py
danielhanchen May 11, 2025
a1b4cbe
Update vllm_utils.py
danielhanchen May 11, 2025
18c96f1
Update vllm_utils.py
danielhanchen May 11, 2025
d38bdef
Update vllm_utils.py
danielhanchen May 11, 2025
0828b02
Update vllm_utils.py
danielhanchen May 11, 2025
6a2fa74
Update vllm_utils.py
danielhanchen May 11, 2025
53bc1f7
Update vllm_utils.py
danielhanchen May 11, 2025
1f39201
Update vllm_utils.py
danielhanchen May 11, 2025
8326902
Update vllm_utils.py
danielhanchen May 11, 2025
9782878
Update vllm_utils.py
danielhanchen May 11, 2025
344dcfa
Update vllm_utils.py
danielhanchen May 11, 2025
c836004
Update vllm_utils.py
danielhanchen May 11, 2025
97378a5
Update vllm_utils.py
danielhanchen May 11, 2025
2eb2947
Merge branch 'main' into nightly
danielhanchen May 13, 2025
2a6b1e0
Update vllm_utils.py
danielhanchen May 13, 2025
eec826b
Update vllm_utils.py
danielhanchen May 13, 2025
d1cbffc
Update vllm_utils.py
danielhanchen May 13, 2025
c977d58
Update vllm_utils.py
danielhanchen May 13, 2025
3c4d27b
Update vllm_utils.py
danielhanchen May 13, 2025
fa03be7
Update vllm_utils.py
danielhanchen May 13, 2025
996ce3e
Update vllm_utils.py
danielhanchen May 13, 2025
f8cb7fc
Update vllm_utils.py
danielhanchen May 13, 2025
e1170ac
Update vllm_utils.py
danielhanchen May 13, 2025
46d2e2a
Update vllm_utils.py
danielhanchen May 13, 2025
f5d80ad
Update vllm_utils.py
danielhanchen May 13, 2025
c6d1240
Update vllm_utils.py
danielhanchen May 13, 2025
77f6075
Update vllm_utils.py
danielhanchen May 13, 2025
880f2ca
Update vllm_utils.py
danielhanchen May 13, 2025
2f93f23
Update vllm_utils.py
danielhanchen May 13, 2025
b058b57
Update vllm_utils.py
danielhanchen May 13, 2025
53ed102
Update vllm_utils.py
danielhanchen May 13, 2025
105411b
Update compiler.py
danielhanchen May 13, 2025
df07dc7
Update vllm_utils.py
danielhanchen May 13, 2025
157eaa5
Update vllm_utils.py
danielhanchen May 13, 2025
e54bd82
Update vllm_utils.py
danielhanchen May 13, 2025
31cfc28
Update vllm_utils.py
danielhanchen May 13, 2025
f561d9c
Update vllm_utils.py
danielhanchen May 13, 2025
13e8de6
Update vllm_utils.py
danielhanchen May 13, 2025
a23600b
Update vllm_utils.py
danielhanchen May 13, 2025
d5efe50
Update vllm_utils.py
danielhanchen May 13, 2025
3a5c948
Update vllm_utils.py
danielhanchen May 13, 2025
8e4a0a9
Update vllm_utils.py
danielhanchen May 13, 2025
ebfaad5
Update vllm_utils.py
danielhanchen May 13, 2025
e1ed6c1
Update vllm_utils.py
danielhanchen May 13, 2025
1afbe08
Update vllm_utils.py
danielhanchen May 13, 2025
b888da6
Update vllm_utils.py
danielhanchen May 13, 2025
dafbfc1
Update rl_replacements.py
danielhanchen May 13, 2025
7d66758
Update rl_replacements.py
danielhanchen May 13, 2025
247eb72
Update rl_replacements.py
danielhanchen May 13, 2025
5340792
Update rl_replacements.py
danielhanchen May 13, 2025
f484ce4
Swap space reduce
danielhanchen May 13, 2025
2c85e6e
Update vllm_utils.py
danielhanchen May 13, 2025
29c51ee
Update vllm_utils.py
danielhanchen May 13, 2025
275e758
Update rl_replacements.py
danielhanchen May 13, 2025
2bed350
Update vllm_utils.py
danielhanchen May 13, 2025
24d059d
Update vllm_utils.py
danielhanchen May 13, 2025
616f64e
Update vllm_utils.py
danielhanchen May 13, 2025
1f0480f
Update vllm_utils.py
danielhanchen May 13, 2025
a30f24f
Update __init__.py
danielhanchen May 13, 2025
beed07d
Update rl_replacements.py
danielhanchen May 13, 2025
9836cc0
Merge branch 'main' into nightly
danielhanchen May 13, 2025
c4ef2e8
Merge branch 'main' into nightly
danielhanchen May 14, 2025
04f0c0a
Update vllm_utils.py
danielhanchen May 14, 2025
b45c15a
Update vllm_utils.py
danielhanchen May 14, 2025
1ef267a
Update vllm_utils.py
danielhanchen May 14, 2025
393a41e
Update vllm_utils.py
danielhanchen May 14, 2025
31ca23a
Update rl_replacements.py
danielhanchen May 14, 2025
3af611a
Update vllm_utils.py
danielhanchen May 14, 2025
c0a4022
Update rl_replacements.py
danielhanchen May 14, 2025
78c2ec6
Revert "Update rl_replacements.py"
danielhanchen May 14, 2025
c1626ad
Update __init__.py
danielhanchen May 14, 2025
1b85d8d
Update patching_utils.py
danielhanchen May 15, 2025
d9a0a97
Merge branch 'main' into nightly
danielhanchen May 15, 2025
0130ea3
Update compiler.py
danielhanchen May 15, 2025
c801467
Update compiler.py
danielhanchen May 15, 2025
a12ceb2
Update compiler.py
danielhanchen May 15, 2025
4f65c72
Update compiler.py
danielhanchen May 15, 2025
dfee5c4
Update compiler.py
danielhanchen May 15, 2025
fc5ed10
Update compiler.py
danielhanchen May 15, 2025
f6f714f
Update compiler.py
danielhanchen May 15, 2025
021e801
Update compiler.py
danielhanchen May 15, 2025
d3ea31d
Fixes
danielhanchen May 15, 2025
63b0b60
Update temporary_patches.py
danielhanchen May 15, 2025
c4ac369
Update temporary_patches.py
danielhanchen May 15, 2025
4804692
Update compiler.py
danielhanchen May 15, 2025
58a33f6
Update compiler.py
danielhanchen May 15, 2025
3f01f75
Update compiler.py
danielhanchen May 15, 2025
7feef84
Update compiler.py
danielhanchen May 15, 2025
588673c
Update compiler.py
danielhanchen May 15, 2025
9ac33e1
Update compiler.py
danielhanchen May 15, 2025
ec181c6
Update compiler.py
danielhanchen May 15, 2025
6afdfc0
Update temporary_patches.py
danielhanchen May 15, 2025
1d38ffc
Update temporary_patches.py
danielhanchen May 15, 2025
5512578
Update temporary_patches.py
danielhanchen May 15, 2025
0823a4e
Update compiler.py
danielhanchen May 15, 2025
8dba2e1
revert
danielhanchen May 15, 2025
5ad93e2
Update temporary_patches.py
danielhanchen May 15, 2025
9b25691
Update temporary_patches.py
danielhanchen May 15, 2025
cbf69fe
Update temporary_patches.py
danielhanchen May 15, 2025
62fc0b7
Update temporary_patches.py
danielhanchen May 15, 2025
35eb9dc
Update temporary_patches.py
danielhanchen May 15, 2025
d43302b
Update temporary_patches.py
danielhanchen May 15, 2025
8fab583
Update __init__.py
danielhanchen May 15, 2025
ed69af7
Update compiler.py
danielhanchen May 15, 2025
3ec4752
Update temporary_patches.py
danielhanchen May 15, 2025
1d10228
Update compiler.py
danielhanchen May 15, 2025
58fc919
Update temporary_patches.py
danielhanchen May 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion unsloth_zoo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__version__ = "2025.5.5"
__version__ = "2025.5.6"

from importlib.util import find_spec
if find_spec("unsloth") is None:
Expand Down
58 changes: 50 additions & 8 deletions unsloth_zoo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def create_standalone_class(

source = f"{compile}\n{source}\n"

left = re.match("[\s\n]{4,}", leftover).span()[1]
left = re.match(r"[\s\n]{4,}", leftover).span()[1]
new_forward = definition + leftover[:left] + \
f"return {module}_forward({parameters})\n"
full_class = full_class.replace(old_source, new_forward)
Expand All @@ -505,6 +505,9 @@ def create_standalone_class(
# Combine all into file
source = source + full_class

# Remove @auto_docstring
source = source.replace("@auto_docstring", "")

# Fix Gemma 3 ignore_index being not set!
source = source.replace("self.config.ignore_index", "-100")
return source
Expand Down Expand Up @@ -1470,18 +1473,45 @@ def unsloth_compile_transformers(
if hasattr(modeling_file, "__UNSLOTH_PATCHED__"): return

# Use transformers model_type logger to suppress the message: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`
exec("model_logger.addFilter(HideLoggingMessage('Setting `use_cache=False`'))", globals(), locals())
exec("model_logger.addFilter(HideLoggingMessage('`use_cache=True`'))", globals(), locals())

# Instead of Inductor Compilation:
try:
import torch._inductor.async_compile
from torch.hub import tqdm
def replaced_tqdm(*args, **kwargs):
kwargs["desc"] = "Unsloth: Compiling kernels"
return tqdm(*args, **kwargs)
torch._inductor.async_compile.tqdm = replaced_tqdm
except:
print("Unsloth: Failed editing tqdm to replace Inductor Compilation:")
pass

# torch_compile_options
UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1"
UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
torch_compile_options = {
"epilogue_fusion" : epilogue_fusion,
"max_autotune" : max_autotune,
"shape_padding" : shape_padding,
"trace.enabled" : UNSLOTH_COMPILE_DEBUG or debug,
"triton.cudagraphs" : cudagraphs,
"epilogue_fusion" : epilogue_fusion,
"max_autotune" : max_autotune,
"shape_padding" : shape_padding,
"trace.enabled" : UNSLOTH_COMPILE_DEBUG or debug,
"triton.cudagraphs" : cudagraphs,
"debug" : UNSLOTH_COMPILE_DEBUG or debug,
"dce" : True,
"memory_planning" : True,
"coordinate_descent_tuning" : UNSLOTH_COMPILE_MAXIMUM,
"trace.graph_diagram" : UNSLOTH_COMPILE_DEBUG or debug,
"compile_threads" : 24,
"combo_kernels" : False, # Causes incompatible gradient sizes on 2.6
"group_fusion" : True,
"disable_progress" : not UNSLOTH_ENABLE_LOGGING,
"verbose_progress" : UNSLOTH_ENABLE_LOGGING,
"triton.multi_kernel" : False, # Sometimes fails
"triton.use_block_ptr" : True,
"triton.enable_persistent_tma_matmul" : True,
"triton.autotune_at_compile_time" : True,
}

# Return logits
Expand Down Expand Up @@ -1705,6 +1735,18 @@ def unsloth_compile_transformers(
bad_torch_modules.add(module)
pass

# Remove decoder layers
if "for layer in self." in source:
print(f"Unsloth: Failed compiling function {module} since it looks like a decoder!")
bad_torch_modules.add(module)
pass

# Remove padding
if "nn.functional.pad" in source or "padding" in source:
print(f"Unsloth: Failed compiling function {module} since there is padding done.")
bad_torch_modules.add(module)
pass

# Check for residual streams optimizations
if fast_residual_stream and "residual" in source:
new_source = patch_residual_stream(source)
Expand Down Expand Up @@ -1790,7 +1832,7 @@ def unsloth_compile_transformers(
# Remove causal masks
do_not_remove = False
for module in remove_causal_masks:
if module.endswith(("ForConditionalGeneration")):
if module.endswith(("ForConditionalGeneration", "Gemma3Model")):
do_not_remove = True
print(f"Unsloth: Will not remove causal mask for {model_location} since it's a VLM!")
break
Expand Down
51 changes: 28 additions & 23 deletions unsloth_zoo/patching_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def patch_model_and_tokenizer(
downcast_rope = True,
fix_embeddings = True,
do_forced_float32 = False,
correct_dtype = None,
):
# All Unsloth Zoo code licensed under LGPLv3
assert(type(downcast_rope) is bool)
Expand Down Expand Up @@ -223,10 +224,12 @@ def patch_model_and_tokenizer(
pass

# Get the most likely correct data type of the model
try:
correct_dtype = _get_dtype(model.config.torch_dtype)
except:
correct_dtype = model.get_input_embeddings().weight.dtype
if correct_dtype is None:
try:
correct_dtype = _get_dtype(model.config.torch_dtype)
except:
correct_dtype = model.get_input_embeddings().weight.dtype
pass
# If we force float32, we first use bfloat16, then downcast to float16
if do_forced_float32:
correct_dtype = torch.float16
Expand All @@ -242,29 +245,31 @@ def patch_model_and_tokenizer(
assert(module.weight.dtype == torch.float32)
torch.cuda.empty_cache()
pass
pass

# Correct torch_dtype
def __fix_dtype(config):
if not hasattr(config, "to_dict"): return
dicts = config.to_dict()
for key, value in dicts.items():
if key == "torch_dtype":
setattr(config, "torch_dtype", torch.float16)
else:
__fix_dtype(getattr(config, key))
m = model
while hasattr(m, "model"):
if hasattr(m, "dtype"):
try: setattr(m, "dtype", torch.float16)
except: pass
if hasattr(m, "config"): __fix_dtype(m.config)
m = m.model
pass
if hasattr(m, "config"): __fix_dtype(m.config)
# Correct torch_dtype
def __fix_dtype(config):
if not hasattr(config, "to_dict"): return
dicts = config.to_dict()
for key, value in dicts.items():
if key == "torch_dtype":
setattr(config, "torch_dtype", correct_dtype)
else:
__fix_dtype(getattr(config, key))
m = model
while hasattr(m, "model"):
if hasattr(m, "dtype"):
try: setattr(m, "dtype", torch.float16)
try: setattr(m, "dtype", correct_dtype)
except: pass
if hasattr(m, "config"): __fix_dtype(m.config)
m = m.model
pass
if hasattr(m, "config"): __fix_dtype(m.config)
if hasattr(m, "dtype"):
try: setattr(m, "dtype", correct_dtype)
except: pass
pass

# Check all params and patch!
for name, module in model.named_modules():
if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
Expand Down
122 changes: 114 additions & 8 deletions unsloth_zoo/temporary_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Union, List, Any, Tuple, Dict, Callable, Optional
import inspect
import torch
import torch.nn
import os
import logging

Expand Down Expand Up @@ -272,6 +273,7 @@ def forward(
**lm_kwargs,
)
labels = None
# We NEVER enter `if labels is not None:` since labels were reset to None above (already accounted for)


logits = outputs.logits
Expand Down Expand Up @@ -307,13 +309,109 @@ def forward(
image_hidden_states=image_features if pixel_values is not None else None,
)
pass

old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters
new_keys = inspect.signature(forward).parameters
if old_keys != new_keys:
pass
else:
transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward
return

# Replacement `forward` that is installed on Gemma3ForConditionalGeneration
# further down, but only when its signature matches the upstream one.
#
# Flow: resolve output flags from `self.config`, mask padded label positions,
# delegate to `self.model` (labels are passed through), project the returned
# hidden states with `self.lm_head`, and return either a tuple or a
# `Gemma3CausalLMOutputWithPast`.
#
# Side effect: when both `labels` and `attention_mask` are given, `labels` is
# modified IN PLACE (masked positions are overwritten with -100), so the
# caller's tensor changes.
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
# Explicit arguments win; otherwise fall back to the values stored on the config.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Overwrite labels at positions where the attention mask is 0 with -100.
# The mask is first moved to the labels' device; the moved mask is also what
# gets passed to `self.model` below.
# NOTE(review): presumably -100 is the loss's ignore index — confirm.
if labels is not None and attention_mask is not None:
attention_mask = attention_mask.to(device = labels.device)
labels[attention_mask == 0] = -100
pass
# Delegate to the inner model. `labels` is forwarded as well.
# NOTE(review): presumably the inner model computes the loss, since
# `outputs.loss` is read further down — confirm.
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**lm_kwargs,
)
# Dropping the local reference makes the `if labels is not None:` test below
# always false.
labels = None
# We NEVER enter `if labels is not None:` below, since labels were already accounted for

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# With the default `logits_to_keep = 0`, the slice is `slice(0, None)`, i.e. every position is kept.
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])

loss = None
# Never true here: `labels` was set to None right after the `self.model` call.
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
# Flatten the tokens
# NOTE(review): uses a bare `nn` name — confirm it is bound in this module
# (`import torch.nn` alone binds only `torch`).
loss_fct = nn.CrossEntropyLoss()

flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss = loss_fct(flat_logits, flat_labels)
# Take the loss from the inner model's outputs.
# NOTE(review): presumably this runs unconditionally (outside the branch above) — confirm.
loss = outputs.loss

# Tuple form: logits followed by the remaining model outputs, with the loss
# prepended only when it is not None.
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output

return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
pass

old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward).parameters
new_keys = inspect.signature(forward).parameters
if old_keys != new_keys:
print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.")
else:
transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration.forward = forward
return
pass
TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration)

Expand Down Expand Up @@ -395,12 +493,21 @@ def _update_causal_mask(

return causal_mask
pass
old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask).parameters
new_keys = inspect.signature(_update_causal_mask).parameters
if old_keys != new_keys:
print("Unsloth: Failed to patch Gemma3ForConditionalGeneration.")

if hasattr(transformers.models.gemma3.modeling_gemma3, "Gemma3Model"):
old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3Model._update_causal_mask).parameters
new_keys = inspect.signature(_update_causal_mask).parameters
if old_keys != new_keys:
print("Unsloth: Failed to patch Gemma3Model.")
else:
transformers.models.gemma3.modeling_gemma3.Gemma3Model._update_causal_mask = _update_causal_mask
else:
transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask
old_keys = inspect.signature(transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask).parameters
new_keys = inspect.signature(_update_causal_mask).parameters
if old_keys != new_keys:
print("Unsloth: Failed to patch Gemma3ForConditionalGeneration._update_causal_mask.")
else:
transformers.models.gemma3.modeling_gemma3.Gemma3ForConditionalGeneration._update_causal_mask = _update_causal_mask
return
pass
TEMPORARY_PATCHES.append(patch_Gemma3ForConditionalGeneration_causal_mask)
Expand Down Expand Up @@ -583,7 +690,6 @@ def patch_SmolVLMForConditionalGeneration_forward():
from typing import List, Optional, Tuple, Union

from transformers.models.smolvlm.modeling_smolvlm import (
CrossEntropyLoss,
SmolVLMCausalLMOutputWithPast,
SmolVLMForConditionalGeneration,
)
Expand Down Expand Up @@ -675,7 +781,7 @@ def forward(
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1).to(
Expand Down