From a8115330de9a4a811de135f73222302012d9a670 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 00:23:26 -0700 Subject: [PATCH 01/22] Update __init__.py --- unsloth_zoo/__init__.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 169e1f41f..7f0c2b658 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -4,4 +4,24 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at https://mozilla.org/MPL/2.0/. -__version__ = "2024.10.5" \ No newline at end of file +__version__ = "2024.11.1" + +import importlib.util +if importlib.util.find_spec("unsloth") is None: + raise ImportError("Please install Unsloth via `pip install unsloth`!") +pass +del importlib.util + +import os +if not ("UNSLOTH_IS_PRESENT" in os.environ): + raise ImportError("Please install Unsloth via `pip install unsloth`!") +pass + +try: + print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") +except: + print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") +pass +# Log Unsloth-Zoo Utilities +os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1" +del os \ No newline at end of file From 37c7074ee80ba757b5f6c36b0e8cf823a656586b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 00:23:43 -0700 Subject: [PATCH 02/22] Update __init__.py --- unsloth_zoo/__init__.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index fca09740c..7f0c2b658 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -1,18 +1,8 @@ -# Unsloth Zoo - Utilities for Unsloth -# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# Copyright (C) 2024-present the Unsloth AI team. All rights reserved. # -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with this program. If not, see . +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at https://mozilla.org/MPL/2.0/. __version__ = "2024.11.1" @@ -26,4 +16,12 @@ if not ("UNSLOTH_IS_PRESENT" in os.environ): raise ImportError("Please install Unsloth via `pip install unsloth`!") pass -del os + +try: + print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") +except: + print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") +pass +# Log Unsloth-Zoo Utilities +os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1" +del os \ No newline at end of file From b200f44be71d545497f582467a2c38e08f216721 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 00:24:04 -0700 Subject: [PATCH 03/22] Update __init__.py --- unsloth_zoo/__init__.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index 7f0c2b658..c7bbcc5d0 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -1,8 +1,18 @@ -# Copyright (C) 2024-present the Unsloth AI team. All rights reserved. +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. # -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at https://mozilla.org/MPL/2.0/. +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . __version__ = "2024.11.1" From 8c379f94372229f6932879650b1576cffa77c81d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 01:18:21 -0700 Subject: [PATCH 04/22] Create patching_utils.py --- unsloth_zoo/patching_utils.py | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 unsloth_zoo/patching_utils.py diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py new file mode 100644 index 000000000..a2572c3c1 --- /dev/null +++ b/unsloth_zoo/patching_utils.py @@ -0,0 +1,90 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +import torch + +__all__ = [ + "patch_compiling_bitsandbytes", + "patch_layernorm", + "patch_torch_compile", +] + +# Also disable compiling on bitsandbytes +def patch_compiling_bitsandbytes(): + # import peft.tuners.lora.bnb + # peft.tuners.lora.bnb.Linear4bit.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) + # peft.tuners.lora.bnb.Linear8bitLt.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) + # return + import bitsandbytes.nn.modules + bitsandbytes.nn.modules.Linear4bit.forward = \ + torch._disable_dynamo(bitsandbytes.nn.modules.Linear4bit.forward) + return +pass + + +def patch_layernorm(fast_layernorm): + import torch.nn + if torch.nn.LayerNorm.__name__ != "Unsloth_LayerNorm": + + from torch.nn import LayerNorm + class Unsloth_LayerNorm(LayerNorm): + def forward(self, X): + return fast_layernorm(self, X) + pass + pass + + torch.nn.LayerNorm = Unsloth_LayerNorm + return +pass + + +def patch_torch_compile(debug = True): + assert(type(debug) is bool) + # Torch compile arguments + torch_compile_arguments = [ + "config.dce = True", + "config.memory_planning = True", + "config.memory_pool = 'combined'", + "config.coordinate_descent_tuning = True", + "config.max_autotune_gemm = False", # GEMM is unnecessary + "config.autotune_multi_device = False", + "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster + "config.aggressive_fusion = False", # Careful changes results! + "config.cuda.enable_cuda_lto = True", + "config.cuda.use_fast_math = True", + "config.cuda.compile_opt_level = '-O2'", + ] + # Torch dynamo arguments + torch_dynamo_arguments = [ + "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 + f"config.suppress_errors = {not debug}", # Supress errors for now + f"config.do_not_emit_runtime_asserts = {not debug}", + "config.cache_size_limit = 1024", # Flex Attention + "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation + ] + import torch._inductor.config as config + for _try_compile_argument in torch_compile_arguments: + try: exec(_try_compile_argument) + except: pass + pass + import torch._dynamo.config as config + for _try_dynamo_argument in torch_dynamo_arguments: + try: exec(_try_dynamo_argument) + except: pass + pass +pass From fa836975a84146891f80c230aa34101b89639f1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:16:18 -0800 Subject: [PATCH 05/22] Bug fixes --- unsloth_zoo/loss_utils.py | 78 +++++++++++++++++++--------------- unsloth_zoo/patching_utils.py | 21 +++++++++ unsloth_zoo/tokenizer_utils.py | 31 ++++++++++++-- 3 files changed, 93 insertions(+), 37 deletions(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 5fd4d4187..756da0a57 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -16,63 +16,73 @@ import torch from packaging.version import Version +torch_nn_functional_cross_entropy = torch.nn.functional.cross_entropy __all__ = [ - "causal_loss_function", - "transformers_losses_patcher", - "patch_loss_function", + "patch_loss_functions", + "post_patch_loss_function", ] -def causal_loss_function(_fast_cross_entropy_loss): +def patch_loss_functions(_fast_cross_entropy_loss): + try: + import transformers.loss.loss_utils + except: + print("Unsloth: Cannot patch loss functions - update transformers for faster modules!") + return None + pass + + # Generic cross entropy loss + def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs): + if ignore_index == -100: + loss = _fast_cross_entropy_loss( + logits = source, + labels = target, + n_items = num_items_in_batch, + ) + else: + reduction = "sum" if num_items_in_batch is not None else "mean" + loss = torch_nn_functional_cross_entropy( + source, + target, + ignore_index = ignore_index, + reduction = reduction, + ) + if reduction == "sum": loss = loss / num_items_in_batch + return loss + pass + + # Causal LM loss def UnslothForCausalLMLoss( logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs ): shift_logits = logits shift_labels = torch.empty_like(labels) shift_labels[..., :-1] = labels[..., 1:] - shift_labels[..., -1] = -100 - loss = _fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - n_items = num_items_in_batch, - ) + shift_labels[..., -1] = ignore_index + loss = unsloth_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) return loss pass if (Version(torch.__version__) < Version("2.4.0")): UnslothForCausalLMLoss = torch._disable_dynamo(UnslothForCausalLMLoss) pass - return UnslothForCausalLMLoss -pass - -def transformers_losses_patcher(UnslothForCausalLMLoss): - def _patch_transformers_losses(): - import re - try: - import transformers.loss.loss_utils - except: - print("Unsloth: Cannot patch loss functions - update transformers for faster modules!") - return - pass + # Now patch the losses! + import transformers.modeling_utils + LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING + LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss - import transformers.modeling_utils - LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING - LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss - - # Remove @property and @lru_cache - if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): - transformers.modeling_utils.PreTrainedModel.loss_function = \ - transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ - pass - print("Unsloth: Patched cross entropy losses.") + # Remove @property and @lru_cache + if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): + transformers.modeling_utils.PreTrainedModel.loss_function = \ + transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ pass - return _patch_transformers_losses + print("Unsloth: Patched cross entropy losses.") pass -def patch_loss_function(model): +def post_patch_loss_function(model): try: # model.loss_function starts as a dict to a loss fx # We invoke it to save it diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index a2572c3c1..802b861ad 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -20,6 +20,7 @@ "patch_compiling_bitsandbytes", "patch_layernorm", "patch_torch_compile", + "patch_regional_compilation", ] # Also disable compiling on bitsandbytes @@ -88,3 +89,23 @@ def patch_torch_compile(debug = True): except: pass pass pass + + +def patch_regional_compilation(): + # Regional torch 2.5 Recompilation - weirdly very slow?? + if torch.nn.ModuleList.__name__ == "UnslothModuleList": return + # Only works for torch 2.5 + if Version(torch.__version__) < Version("2.5.0"): return + + old_module_list = torch.nn.ModuleList + + def UnslothModuleList(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: + args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] + return old_module_list(*args, **kwargs) + pass + UnslothModuleList.__doc__ = old_module_list.__doc__ + + torch.nn.ModuleList = UnslothModuleList + return +pass diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 5fcc4a26d..50af5503b 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -104,10 +104,16 @@ def add_new_tokens( mean_lm_head = mean_lm_head .to(torch.float32) # Get old lengths - old_input_length = model.get_input_embeddings ().weight.shape[0] - old_output_length = model.get_output_embeddings().weight.shape[0] + old_input_embedding = model.get_input_embeddings ().weight + old_output_embedding = model.get_output_embeddings().weight + old_input_length = old_input_embedding .shape[0] + old_output_length = old_output_embedding.shape[0] old_config_size = model.config.vocab_size + # Check for tied weights as well + is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ + or (model.config.tie_word_embeddings) + # Add tokens! old_length = len(tokenizer) tokenizer.add_tokens(new_tokens) @@ -165,7 +171,26 @@ def add_new_tokens( internal_model = internal_model.model pass internal_model._need_to_train_embeddings = True - + + # Fix up all vocab sizes + current_model = model + while hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model = current_model.model + if hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + pass + + # Must tie lm_head and embed_tokens if they are tied! + # Otherwise error will occur on saving models ie use save_model + if is_tied: model.tie_weights() + + # Clear deleted GPU items + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() return pass From 54e0c6400829cf376a344450bbaa985ec0a51a6a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:19:20 -0800 Subject: [PATCH 06/22] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 01a610c21..2c9b740d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "triton", "packaging", "tyro", - "transformers>=4.44.2", + "transformers>=4.46.1", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", From 5178b392dcda7381fdca958e2ac71f71346b5e8d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:22:36 -0800 Subject: [PATCH 07/22] Update __init__.py --- unsloth_zoo/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index c7bbcc5d0..6b2f350bd 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -16,11 +16,11 @@ __version__ = "2024.11.1" -import importlib.util -if importlib.util.find_spec("unsloth") is None: +from importlib.util import find_spec +if find_spec("unsloth") is None: raise ImportError("Please install Unsloth via `pip install unsloth`!") pass -del importlib.util +del find_spec import os if not ("UNSLOTH_IS_PRESENT" in os.environ): @@ -34,4 +34,4 @@ pass # Log Unsloth-Zoo Utilities os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1" -del os \ No newline at end of file +del os From ac18419227e77737134d471e7c39a88d28cb2738 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:10:00 -0800 Subject: [PATCH 08/22] O3 --- unsloth_zoo/gradient_checkpointing.py | 2 +- unsloth_zoo/patching_utils.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 768068146..54f7d815e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -167,7 +167,7 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, def patch_gradient_checkpointing(): - print("Unsloth: Patching Gradient Checkpointing with Unsloth's special version!") + print("Unsloth: Patched gradient checkpointing for long context finetuning.") import torch.utils if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 802b861ad..430a456c9 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -54,18 +54,19 @@ def forward(self, X): pass -def patch_torch_compile(debug = True): +def patch_torch_compile(debug = True, O3 = False): assert(type(debug) is bool) + assert(type(O3) is bool) # Torch compile arguments torch_compile_arguments = [ "config.dce = True", "config.memory_planning = True", "config.memory_pool = 'combined'", "config.coordinate_descent_tuning = True", - "config.max_autotune_gemm = False", # GEMM is unnecessary + f"config.max_autotune_gemm = {O3}", # GEMM is unnecessary "config.autotune_multi_device = False", "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster - "config.aggressive_fusion = False", # Careful changes results! + f"config.aggressive_fusion = {O3}", # Careful changes results! "config.cuda.enable_cuda_lto = True", "config.cuda.use_fast_math = True", "config.cuda.compile_opt_level = '-O2'", From 8b3aade379c69fda4f1e7e9367ba2947bb35267c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:15:21 -0800 Subject: [PATCH 09/22] Update tokenizer_utils.py --- unsloth_zoo/tokenizer_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 50af5503b..08ddf1d8d 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -406,9 +406,6 @@ def patch_tokenizer(model, tokenizer): joiner = "\1\0=+=\0\1" number_repetitions = 3 - 1 # Number of reserved tokens needed - if model is not None: - model.config.update({"unsloth_version" : __version__}) - bad_pad_token = False if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! From dc79f8d6806c42901e1fc2b6a99b6ffa63ad1f9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:49:17 -0800 Subject: [PATCH 10/22] Post patch --- unsloth_zoo/patching_utils.py | 144 ++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 430a456c9..80a7aa4c9 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -21,6 +21,7 @@ "patch_layernorm", "patch_torch_compile", "patch_regional_compilation", + "patch_model_and_tokenizer", ] # Also disable compiling on bitsandbytes @@ -110,3 +111,146 @@ def UnslothModuleList(*args, **kwargs): torch.nn.ModuleList = UnslothModuleList return pass + + +def patch_model_and_tokenizer(model, tokenizer): + import gc + + # Torch.compile fails on embedding matrix?? + try: old_input_embedding = model.get_input_embeddings ().weight + except: return model, tokenizer + + # Maybe not all models have a lm_head? + try: old_output_embedding = model.get_output_embeddings().weight + except: old_output_embedding = torch.zeros(0) + + # Check for tied weights as well + is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ + or (model.config.tie_word_embeddings) + + # Check pad token's id -> we need to expand the embedding + if len(tokenizer) > old_input_embedding.shape[0]: + # Workaround randomnly fixes it for torch versions < 2. + requires_grad = old_input_embedding.requires_grad + old_input_embedding.requires_grad_(False) + old_input_embedding.resize_(len(tokenizer), old_input_embedding.shape[1]) + old_input_embedding.requires_grad_(requires_grad) + + # Fix up all vocab sizes + current_model = model + while hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model.update({"unsloth_optimized" : True}) + current_model = current_model.model + if hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model.update({"unsloth_optimized" : True}) + pass + pass + + model.set_input_embeddings( + torch.nn.Embedding.from_pretrained( + old_input_embedding, + padding_idx = getattr(model.config, "pad_token_id", None), + ) + ) + + # We also do this for the lm_head + if old_output_embedding.numel() != 0: + + requires_grad = old_output_embedding.requires_grad + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + + lm_head.weight = old_output_embedding if not is_tied else old_input_embedding + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + + lm_head.weight.requires_grad_(requires_grad) + model.set_output_embeddings(lm_head) + if hasattr(model, "lm_head"): model.lm_head = lm_head + + correct_dtype = lm_head.weight.dtype + else: + correct_dtype = old_input_embedding.dtype + pass + + # Must tie lm_head and embed_tokens if they are tied! + # Otherwise error will occur on saving models ie use save_model + if is_tied: model.tie_weights() + + # Also fix torch_dtype + internal_model = model + while hasattr(internal_model, "model"): + if hasattr(internal_model, "config"): + if internal_model.config.torch_dtype == "float32": + internal_model.config.torch_dtype = torch.float32 + elif internal_model.config.torch_dtype == "bfloat16": + internal_model.config.torch_dtype = torch.bfloat16 + elif internal_model.config.torch_dtype == "float16": + internal_model.config.torch_dtype = torch.float16 + pass + pass + internal_model = internal_model.model + pass + if hasattr(internal_model, "config"): + if internal_model.config.torch_dtype == "float32": + internal_model.config.torch_dtype = torch.float32 + elif internal_model.config.torch_dtype == "bfloat16": + internal_model.config.torch_dtype = torch.bfloat16 + elif internal_model.config.torch_dtype == "float16": + internal_model.config.torch_dtype = torch.float16 + pass + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + try: + from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit + except: + raise ImportError("Unsloth: Please install bitsandbytes via `pip install bitsandbytes`") + try: + from peft.tuners.lora import Linear4bit as Peft_Linear4bit + except: + raise ImportError("Unsloth: Please install peft via `pip install peft`") + pass + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! + module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + pass + # Downcast RoPE embedding to correct data type + if downcast_rope and ((name.endswith("rotary_emb") or hasattr(module, "cos_cached"))): + + if hasattr(module, "cos_cached") and \ + (module.cos_cached.dtype != correct_dtype): + + module.cos_cached = module.cos_cached.to(correct_dtype) + module.sin_cached = module.sin_cached.to(correct_dtype) + + elif hasattr(module, "short_cos_cached") and \ + (module.short_cos_cached.dtype != correct_dtype): + + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) + module.short_sin_cached = module.short_sin_cached.to(correct_dtype) + pass + pass + pass + + # Clear deleted GPU items + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return model, tokenizer +pass From 8a27007a9a15ee17a1ef8c09e907492931648f58 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:58:18 -0800 Subject: [PATCH 11/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 80a7aa4c9..2da87cc65 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -113,7 +113,8 @@ def UnslothModuleList(*args, **kwargs): pass -def patch_model_and_tokenizer(model, tokenizer): +def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): + assert(type(downcast_rope) is bool) import gc # Torch.compile fails on embedding matrix?? From 589697ada1dc778d6b427212fc954150ca681cf3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 23:41:50 -0800 Subject: [PATCH 12/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 2da87cc65..25dd44f0d 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -80,16 +80,16 @@ def patch_torch_compile(debug = True, O3 = False): "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation ] - import torch._inductor.config as config - for _try_compile_argument in torch_compile_arguments: - try: exec(_try_compile_argument) - except: pass - pass - import torch._dynamo.config as config - for _try_dynamo_argument in torch_dynamo_arguments: - try: exec(_try_dynamo_argument) - except: pass - pass + # import torch._inductor.config as config + # for _try_compile_argument in torch_compile_arguments: + # try: exec(_try_compile_argument) + # except: pass + # pass + # import torch._dynamo.config as config + # for _try_dynamo_argument in torch_dynamo_arguments: + # try: exec(_try_dynamo_argument) + # except: pass + # pass pass From ac1334e360cda605a84983cffd32393597c3bb02 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 01:20:34 -0800 Subject: [PATCH 13/22] Update gradient_checkpointing.py --- unsloth_zoo/gradient_checkpointing.py | 65 ++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 54f7d815e..4e256c44e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -24,6 +24,11 @@ "prepare_n_gradient_checkpoints", "Unsloth_Offloaded_Gradient_Checkpointer", "unsloth_offloaded_gradient_checkpoint", + "patch_unsloth_gradient_checkpointing", + "unpatch_unsloth_gradient_checkpointing", + + "Unsloth_Gradient_Checkpointer", + "unsloth_gradient_checkpoint", "patch_gradient_checkpointing", "unpatch_gradient_checkpointing", ] @@ -155,6 +160,35 @@ def backward(ctx, dY): pass +class Unsloth_Gradient_Checkpointer(torch.autograd.Function): + """ + Same as normal gradient checkpointing but cleaner + """ + @staticmethod + @torch_amp_custom_fwd + def forward(ctx, forward_function, hidden_states, *args): + with torch.no_grad(): + output = forward_function(hidden_states, *args) + ctx.save_for_backward(hidden_states) + ctx.forward_function = forward_function + ctx.args = args + return output + pass + + @staticmethod + @torch_amp_custom_bwd + def backward(ctx, dY): + (hidden_states,) = ctx.saved_tensors + hidden_states = hidden_states.detach() + hidden_states.requires_grad_(True) + with torch.enable_grad(): + (output,) = ctx.forward_function(hidden_states, *ctx.args) + torch.autograd.backward(output, dY) + return (None, hidden_states.grad,) + (None,)*len(ctx.args) + pass +pass + + def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) pass @@ -166,7 +200,18 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, pass -def patch_gradient_checkpointing(): +def unsloth_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): + return Unsloth_Gradient_Checkpointer.apply(function, *args) +pass +if (Version(torch.__version__) < Version("2.4.0")) and \ + not hasattr(unsloth_gradient_checkpoint, "__wrapped__"): + unsloth_gradient_checkpoint = torch._disable_dynamo( + unsloth_gradient_checkpoint + ) +pass + + +def patch_unsloth_gradient_checkpointing(): print("Unsloth: Patched gradient checkpointing for long context finetuning.") import torch.utils if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return @@ -175,6 +220,24 @@ def patch_gradient_checkpointing(): pass +def patch_gradient_checkpointing(): + print("Unsloth: Patched gradient checkpointing.") + import torch.utils + if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_gradient_checkpoint": return + torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint + torch.utils.checkpoint.checkpoint = unsloth_gradient_checkpoint +pass + + +def unpatch_unsloth_gradient_checkpointing(): + import torch.utils + if hasattr(torch.utils.checkpoint, "_old_checkpoint"): + torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint + del torch.utils.checkpoint._old_checkpoint + pass +pass + + def unpatch_gradient_checkpointing(): import torch.utils if hasattr(torch.utils.checkpoint, "_old_checkpoint"): From a6892c70bf30cb631068df236de03f70aee452bb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 17:56:20 -0800 Subject: [PATCH 14/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 38 ++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 25dd44f0d..776e57f74 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -60,14 +60,25 @@ def patch_torch_compile(debug = True, O3 = False): assert(type(O3) is bool) # Torch compile arguments torch_compile_arguments = [ + f"config.debug = {debug}", "config.dce = True", "config.memory_planning = True", "config.memory_pool = 'combined'", - "config.coordinate_descent_tuning = True", + "config.efficient_conv_bn_eval_fx_passes = True", # Reduces stability a little bit + "config.dynamic_scale_rblock = True", # Scale down RBLOCK for better occupancy + "config.reorder_for_compute_comm_overlap = True", # # enable reordering pass for increasing overlap between compute and communication + f"config.max_autotune = {O3}", # enable slow autotuning passes to select algorithms + f"config.max_autotune_pointwise = {O3}", # enable slow autotuning passes to select pointwise/reductions algorithms f"config.max_autotune_gemm = {O3}", # GEMM is unnecessary - "config.autotune_multi_device = False", "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster + "config.autotune_fallback_to_aten = True", # Fallback to ATEN backend + "config.autotune_multi_device = True", # If autotuning in subprocess, whether to use multiple devices + "config.coordinate_descent_tuning = True", f"config.aggressive_fusion = {O3}", # Careful changes results! + "config.combo_kernels = True", # Experimental - enable the combo kernel that combines data-independent kernels + "config.combo_kernel_foreach_dynamic_shapes = True", + "config.freezing = False", # Freezes weights --> ** only useful for inference ** + "config.triton.multi_kernel = True", # use tuning to pick between different subkernels "config.cuda.enable_cuda_lto = True", "config.cuda.use_fast_math = True", "config.cuda.compile_opt_level = '-O2'", @@ -79,17 +90,20 @@ def patch_torch_compile(debug = True, O3 = False): f"config.do_not_emit_runtime_asserts = {not debug}", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation + "config.numpy_default_float = 'float32'", + "config.compiled_autograd = True", # New Torch 2.4 feature which can compile backwards passes + # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html ] - # import torch._inductor.config as config - # for _try_compile_argument in torch_compile_arguments: - # try: exec(_try_compile_argument) - # except: pass - # pass - # import torch._dynamo.config as config - # for _try_dynamo_argument in torch_dynamo_arguments: - # try: exec(_try_dynamo_argument) - # except: pass - # pass + import torch._inductor.config as config + for _try_compile_argument in torch_compile_arguments: + try: exec(_try_compile_argument) + except: pass + pass + import torch._dynamo.config as config + for _try_dynamo_argument in torch_dynamo_arguments: + try: exec(_try_dynamo_argument) + except: pass + pass pass From a768c1b4009e5c74baaa394949e71254b516c8f8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 18:04:23 -0800 Subject: [PATCH 15/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 776e57f74..bae5d8d48 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -72,7 +72,7 @@ def patch_torch_compile(debug = True, O3 = False): f"config.max_autotune_gemm = {O3}", # GEMM is unnecessary "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster "config.autotune_fallback_to_aten = True", # Fallback to ATEN backend - "config.autotune_multi_device = True", # If autotuning in subprocess, whether to use multiple devices + "config.autotune_multi_device = False", # If autotuning in subprocess, whether to use multiple devices "config.coordinate_descent_tuning = True", f"config.aggressive_fusion = {O3}", # Careful changes results! "config.combo_kernels = True", # Experimental - enable the combo kernel that combines data-independent kernels From f2a88780a0a1ee26de1423330e3bad4b511c3248 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 18:25:46 -0800 Subject: [PATCH 16/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index bae5d8d48..a316bde0a 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -66,13 +66,13 @@ def patch_torch_compile(debug = True, O3 = False): "config.memory_pool = 'combined'", "config.efficient_conv_bn_eval_fx_passes = True", # Reduces stability a little bit "config.dynamic_scale_rblock = True", # Scale down RBLOCK for better occupancy - "config.reorder_for_compute_comm_overlap = True", # # enable reordering pass for increasing overlap between compute and communication + # "config.reorder_for_compute_comm_overlap = True", # # enable reordering pass for increasing overlap between compute and communication f"config.max_autotune = {O3}", # enable slow autotuning passes to select algorithms f"config.max_autotune_pointwise = {O3}", # enable slow autotuning passes to select pointwise/reductions algorithms f"config.max_autotune_gemm = {O3}", # GEMM is unnecessary "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster "config.autotune_fallback_to_aten = True", # Fallback to ATEN backend - "config.autotune_multi_device = False", # If autotuning in subprocess, whether to use multiple devices + "config.autotune_multi_device = True", # If autotuning in subprocess, whether to use multiple devices "config.coordinate_descent_tuning = True", f"config.aggressive_fusion = {O3}", # Careful changes results! "config.combo_kernels = True", # Experimental - enable the combo kernel that combines data-independent kernels From beee64abe47a83c88936c6380b510ff61fff415c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 18:36:55 -0800 Subject: [PATCH 17/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index a316bde0a..1ead2b049 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -66,6 +66,7 @@ def patch_torch_compile(debug = True, O3 = False): "config.memory_pool = 'combined'", "config.efficient_conv_bn_eval_fx_passes = True", # Reduces stability a little bit "config.dynamic_scale_rblock = True", # Scale down RBLOCK for better occupancy + # Disable reorder_for_compute_comm_overlap since it errors for non multi GPU systems # "config.reorder_for_compute_comm_overlap = True", # # enable reordering pass for increasing overlap between compute and communication f"config.max_autotune = {O3}", # enable slow autotuning passes to select algorithms f"config.max_autotune_pointwise = {O3}", # enable slow autotuning passes to select pointwise/reductions algorithms From 408f3279426c65ce3fd40b673287336d39c4aa18 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:45:14 -0800 Subject: [PATCH 18/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 1ead2b049..e8bce4e42 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -58,6 +58,13 @@ def forward(self, X): def patch_torch_compile(debug = True, O3 = False): assert(type(debug) is bool) assert(type(O3) is bool) + import os + if debug: + os.environ["TORCHDYNAMO_VERBOSE"] = "1" + else: + os.environ.pop("TORCHDYNAMO_VERBOSE", None) + pass + # Torch compile arguments torch_compile_arguments = [ f"config.debug = {debug}", From 0a8385bcbc4f71bfb3ec25fc10a293981c2b44c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:47:46 -0800 Subject: [PATCH 19/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index e8bce4e42..4a9e6c2e2 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -61,10 +61,12 @@ def patch_torch_compile(debug = True, O3 = False): import os if debug: os.environ["TORCHDYNAMO_VERBOSE"] = "1" + os.environ["TORCH_LOGS"] = "+dynamo" else: os.environ.pop("TORCHDYNAMO_VERBOSE", None) + os.environ.pop("TORCH_LOGS", None) pass - + # Torch compile arguments torch_compile_arguments = [ f"config.debug = {debug}", From 664fd43a099ecab31932c67543225cda0f0b85f5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:53:27 -0800 Subject: [PATCH 20/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 4a9e6c2e2..12e0b8d8f 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -58,10 +58,12 @@ def forward(self, X): def patch_torch_compile(debug = True, O3 = False): assert(type(debug) is bool) assert(type(O3) is bool) - import os + import os, logging if debug: os.environ["TORCHDYNAMO_VERBOSE"] = "1" os.environ["TORCH_LOGS"] = "+dynamo" + torch._logging.set_logs(dynamo = logging.DEBUG) + torch._dynamo.config.verbose = True else: os.environ.pop("TORCHDYNAMO_VERBOSE", None) os.environ.pop("TORCH_LOGS", None) From 0c120905ce0cc2da7169f2ab8f3d73fd49c6aca6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:41:56 -0800 Subject: [PATCH 21/22] Update patching_utils.py --- unsloth_zoo/patching_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py index 12e0b8d8f..1ccf19563 100644 --- a/unsloth_zoo/patching_utils.py +++ b/unsloth_zoo/patching_utils.py @@ -62,7 +62,7 @@ def patch_torch_compile(debug = True, O3 = False): if debug: os.environ["TORCHDYNAMO_VERBOSE"] = "1" os.environ["TORCH_LOGS"] = "+dynamo" - torch._logging.set_logs(dynamo = logging.DEBUG) + torch._logging.set_logs(dynamo = logging.DEBUG, inductor = logging.DEBUG) torch._dynamo.config.verbose = True else: os.environ.pop("TORCHDYNAMO_VERBOSE", None) From 03b91cd1206f74b58cb2e8dc44bbf3811b00e9c9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 12:05:25 -0800 Subject: [PATCH 22/22] Update loss_utils.py --- unsloth_zoo/loss_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 756da0a57..22a1d681b 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -74,7 +74,8 @@ def UnslothForCausalLMLoss( LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss # Remove @property and @lru_cache - if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): + if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget") and \ + hasattr(transformers.modeling_utils.PreTrainedModel.loss_function.fget, "__wrapped__"): transformers.modeling_utils.PreTrainedModel.loss_function = \ transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ pass