diff --git a/pyproject.toml b/pyproject.toml index 01a610c21..2c9b740d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ "triton", "packaging", "tyro", - "transformers>=4.44.2", + "transformers>=4.46.1", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", diff --git a/unsloth_zoo/__init__.py b/unsloth_zoo/__init__.py index fca09740c..6b2f350bd 100644 --- a/unsloth_zoo/__init__.py +++ b/unsloth_zoo/__init__.py @@ -16,14 +16,22 @@ __version__ = "2024.11.1" -import importlib.util -if importlib.util.find_spec("unsloth") is None: +from importlib.util import find_spec +if find_spec("unsloth") is None: raise ImportError("Please install Unsloth via `pip install unsloth`!") pass -del importlib.util +del find_spec import os if not ("UNSLOTH_IS_PRESENT" in os.environ): raise ImportError("Please install Unsloth via `pip install unsloth`!") pass + +try: + print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") +except: + print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") +pass +# Log Unsloth-Zoo Utilities +os.environ["UNSLOTH_ZOO_IS_PRESENT"] = "1" del os diff --git a/unsloth_zoo/gradient_checkpointing.py b/unsloth_zoo/gradient_checkpointing.py index 768068146..4e256c44e 100644 --- a/unsloth_zoo/gradient_checkpointing.py +++ b/unsloth_zoo/gradient_checkpointing.py @@ -24,6 +24,11 @@ "prepare_n_gradient_checkpoints", "Unsloth_Offloaded_Gradient_Checkpointer", "unsloth_offloaded_gradient_checkpoint", + "patch_unsloth_gradient_checkpointing", + "unpatch_unsloth_gradient_checkpointing", + + "Unsloth_Gradient_Checkpointer", + "unsloth_gradient_checkpoint", "patch_gradient_checkpointing", "unpatch_gradient_checkpointing", ] @@ -155,6 +160,35 @@ def backward(ctx, dY): pass +class Unsloth_Gradient_Checkpointer(torch.autograd.Function): + """ + Same as normal gradient checkpointing but cleaner + """ + @staticmethod + @torch_amp_custom_fwd + def forward(ctx, forward_function, hidden_states, *args): + with torch.no_grad(): + output = forward_function(hidden_states, *args) + ctx.save_for_backward(hidden_states) + ctx.forward_function = forward_function + ctx.args = args + return output + pass + + @staticmethod + @torch_amp_custom_bwd + def backward(ctx, dY): + (hidden_states,) = ctx.saved_tensors + hidden_states = hidden_states.detach() + hidden_states.requires_grad_(True) + with torch.enable_grad(): + (output,) = ctx.forward_function(hidden_states, *ctx.args) + torch.autograd.backward(output, dY) + return (None, hidden_states.grad,) + (None,)*len(ctx.args) + pass +pass + + def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) pass @@ -166,8 +200,19 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, pass -def patch_gradient_checkpointing(): - print("Unsloth: Patching Gradient Checkpointing with Unsloth's special version!") +def unsloth_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): + return Unsloth_Gradient_Checkpointer.apply(function, *args) +pass +if (Version(torch.__version__) < Version("2.4.0")) and \ + not hasattr(unsloth_gradient_checkpoint, "__wrapped__"): + unsloth_gradient_checkpoint = torch._disable_dynamo( + unsloth_gradient_checkpoint + ) +pass + + +def patch_unsloth_gradient_checkpointing(): + print("Unsloth: Patched gradient checkpointing for long context finetuning.") import torch.utils if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint @@ -175,6 +220,24 @@ def patch_gradient_checkpointing(): pass +def patch_gradient_checkpointing(): + print("Unsloth: Patched gradient checkpointing.") + import torch.utils + if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_gradient_checkpoint": return + torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint + torch.utils.checkpoint.checkpoint = unsloth_gradient_checkpoint +pass + + +def unpatch_unsloth_gradient_checkpointing(): + import torch.utils + if hasattr(torch.utils.checkpoint, "_old_checkpoint"): + torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint + del torch.utils.checkpoint._old_checkpoint + pass +pass + + def unpatch_gradient_checkpointing(): import torch.utils if hasattr(torch.utils.checkpoint, "_old_checkpoint"): diff --git a/unsloth_zoo/loss_utils.py b/unsloth_zoo/loss_utils.py index 5fd4d4187..22a1d681b 100644 --- a/unsloth_zoo/loss_utils.py +++ b/unsloth_zoo/loss_utils.py @@ -16,63 +16,74 @@ import torch from packaging.version import Version +torch_nn_functional_cross_entropy = torch.nn.functional.cross_entropy __all__ = [ - "causal_loss_function", - "transformers_losses_patcher", - "patch_loss_function", + "patch_loss_functions", + "post_patch_loss_function", ] -def causal_loss_function(_fast_cross_entropy_loss): +def patch_loss_functions(_fast_cross_entropy_loss): + try: + import transformers.loss.loss_utils + except: + print("Unsloth: Cannot patch loss functions - update transformers for faster modules!") + return None + pass + + # Generic cross entropy loss + def unsloth_fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs): + if ignore_index == -100: + loss = _fast_cross_entropy_loss( + logits = source, + labels = target, + n_items = num_items_in_batch, + ) + else: + reduction = "sum" if num_items_in_batch is not None else "mean" + loss = torch_nn_functional_cross_entropy( + source, + target, + ignore_index = ignore_index, + reduction = reduction, + ) + if reduction == "sum": loss = loss / num_items_in_batch + return loss + pass + + # Causal LM loss def UnslothForCausalLMLoss( logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs ): shift_logits = logits shift_labels = torch.empty_like(labels) shift_labels[..., :-1] = labels[..., 1:] - shift_labels[..., -1] = -100 - loss = _fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - n_items = num_items_in_batch, - ) + shift_labels[..., -1] = ignore_index + loss = unsloth_fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs) return loss pass if (Version(torch.__version__) < Version("2.4.0")): UnslothForCausalLMLoss = torch._disable_dynamo(UnslothForCausalLMLoss) pass - return UnslothForCausalLMLoss -pass - -def transformers_losses_patcher(UnslothForCausalLMLoss): - def _patch_transformers_losses(): - import re - try: - import transformers.loss.loss_utils - except: - print("Unsloth: Cannot patch loss functions - update transformers for faster modules!") - return - pass + # Now patch the losses! + import transformers.modeling_utils + LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING + LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss - import transformers.modeling_utils - LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING - LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss - - # Remove @property and @lru_cache - if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): - transformers.modeling_utils.PreTrainedModel.loss_function = \ - transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ - pass - print("Unsloth: Patched cross entropy losses.") + # Remove @property and @lru_cache + if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget") and \ + hasattr(transformers.modeling_utils.PreTrainedModel.loss_function.fget, "__wrapped__"): + transformers.modeling_utils.PreTrainedModel.loss_function = \ + transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ pass - return _patch_transformers_losses + print("Unsloth: Patched cross entropy losses.") pass -def patch_loss_function(model): +def post_patch_loss_function(model): try: # model.loss_function starts as a dict to a loss fx # We invoke it to save it diff --git a/unsloth_zoo/patching_utils.py b/unsloth_zoo/patching_utils.py new file mode 100644 index 000000000..1ccf19563 --- /dev/null +++ b/unsloth_zoo/patching_utils.py @@ -0,0 +1,283 @@ +# Unsloth Zoo - Utilities for Unsloth +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see . + +import torch + +__all__ = [ + "patch_compiling_bitsandbytes", + "patch_layernorm", + "patch_torch_compile", + "patch_regional_compilation", + "patch_model_and_tokenizer", +] + +# Also disable compiling on bitsandbytes +def patch_compiling_bitsandbytes(): + # import peft.tuners.lora.bnb + # peft.tuners.lora.bnb.Linear4bit.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) + # peft.tuners.lora.bnb.Linear8bitLt.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) + # return + import bitsandbytes.nn.modules + bitsandbytes.nn.modules.Linear4bit.forward = \ + torch._disable_dynamo(bitsandbytes.nn.modules.Linear4bit.forward) + return +pass + + +def patch_layernorm(fast_layernorm): + import torch.nn + if torch.nn.LayerNorm.__name__ != "Unsloth_LayerNorm": + + from torch.nn import LayerNorm + class Unsloth_LayerNorm(LayerNorm): + def forward(self, X): + return fast_layernorm(self, X) + pass + pass + + torch.nn.LayerNorm = Unsloth_LayerNorm + return +pass + + +def patch_torch_compile(debug = True, O3 = False): + assert(type(debug) is bool) + assert(type(O3) is bool) + import os, logging + if debug: + os.environ["TORCHDYNAMO_VERBOSE"] = "1" + os.environ["TORCH_LOGS"] = "+dynamo" + torch._logging.set_logs(dynamo = logging.DEBUG, inductor = logging.DEBUG) + torch._dynamo.config.verbose = True + else: + os.environ.pop("TORCHDYNAMO_VERBOSE", None) + os.environ.pop("TORCH_LOGS", None) + pass + + # Torch compile arguments + torch_compile_arguments = [ + f"config.debug = {debug}", + "config.dce = True", + "config.memory_planning = True", + "config.memory_pool = 'combined'", + "config.efficient_conv_bn_eval_fx_passes = True", # Reduces stability a little bit + "config.dynamic_scale_rblock = True", # Scale down RBLOCK for better occupancy + # Disable reorder_for_compute_comm_overlap since it errors for non multi GPU systems + # "config.reorder_for_compute_comm_overlap = True", # # enable reordering pass for increasing overlap between compute and communication + f"config.max_autotune = {O3}", # enable slow autotuning passes to select algorithms + f"config.max_autotune_pointwise = {O3}", # enable slow autotuning passes to select pointwise/reductions algorithms + f"config.max_autotune_gemm = {O3}", # GEMM is unnecessary + "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster + "config.autotune_fallback_to_aten = True", # Fallback to ATEN backend + "config.autotune_multi_device = True", # If autotuning in subprocess, whether to use multiple devices + "config.coordinate_descent_tuning = True", + f"config.aggressive_fusion = {O3}", # Careful changes results! + "config.combo_kernels = True", # Experimental - enable the combo kernel that combines data-independent kernels + "config.combo_kernel_foreach_dynamic_shapes = True", + "config.freezing = False", # Freezes weights --> ** only useful for inference ** + "config.triton.multi_kernel = True", # use tuning to pick between different subkernels + "config.cuda.enable_cuda_lto = True", + "config.cuda.use_fast_math = True", + "config.cuda.compile_opt_level = '-O2'", + ] + # Torch dynamo arguments + torch_dynamo_arguments = [ + "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 + f"config.suppress_errors = {not debug}", # Supress errors for now + f"config.do_not_emit_runtime_asserts = {not debug}", + "config.cache_size_limit = 1024", # Flex Attention + "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation + "config.numpy_default_float = 'float32'", + "config.compiled_autograd = True", # New Torch 2.4 feature which can compile backwards passes + # https://pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html + ] + import torch._inductor.config as config + for _try_compile_argument in torch_compile_arguments: + try: exec(_try_compile_argument) + except: pass + pass + import torch._dynamo.config as config + for _try_dynamo_argument in torch_dynamo_arguments: + try: exec(_try_dynamo_argument) + except: pass + pass +pass + + +def patch_regional_compilation(): + # Regional torch 2.5 Recompilation - weirdly very slow?? + if torch.nn.ModuleList.__name__ == "UnslothModuleList": return + # Only works for torch 2.5 + if Version(torch.__version__) < Version("2.5.0"): return + + old_module_list = torch.nn.ModuleList + + def UnslothModuleList(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: + args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] + return old_module_list(*args, **kwargs) + pass + UnslothModuleList.__doc__ = old_module_list.__doc__ + + torch.nn.ModuleList = UnslothModuleList + return +pass + + +def patch_model_and_tokenizer(model, tokenizer, downcast_rope = True): + assert(type(downcast_rope) is bool) + import gc + + # Torch.compile fails on embedding matrix?? + try: old_input_embedding = model.get_input_embeddings ().weight + except: return model, tokenizer + + # Maybe not all models have a lm_head? + try: old_output_embedding = model.get_output_embeddings().weight + except: old_output_embedding = torch.zeros(0) + + # Check for tied weights as well + is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ + or (model.config.tie_word_embeddings) + + # Check pad token's id -> we need to expand the embedding + if len(tokenizer) > old_input_embedding.shape[0]: + # Workaround randomnly fixes it for torch versions < 2. + requires_grad = old_input_embedding.requires_grad + old_input_embedding.requires_grad_(False) + old_input_embedding.resize_(len(tokenizer), old_input_embedding.shape[1]) + old_input_embedding.requires_grad_(requires_grad) + + # Fix up all vocab sizes + current_model = model + while hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model.update({"unsloth_optimized" : True}) + current_model = current_model.model + if hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model.update({"unsloth_optimized" : True}) + pass + pass + + model.set_input_embeddings( + torch.nn.Embedding.from_pretrained( + old_input_embedding, + padding_idx = getattr(model.config, "pad_token_id", None), + ) + ) + + # We also do this for the lm_head + if old_output_embedding.numel() != 0: + + requires_grad = old_output_embedding.requires_grad + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + + lm_head.weight = old_output_embedding if not is_tied else old_input_embedding + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + + lm_head.weight.requires_grad_(requires_grad) + model.set_output_embeddings(lm_head) + if hasattr(model, "lm_head"): model.lm_head = lm_head + + correct_dtype = lm_head.weight.dtype + else: + correct_dtype = old_input_embedding.dtype + pass + + # Must tie lm_head and embed_tokens if they are tied! + # Otherwise error will occur on saving models ie use save_model + if is_tied: model.tie_weights() + + # Also fix torch_dtype + internal_model = model + while hasattr(internal_model, "model"): + if hasattr(internal_model, "config"): + if internal_model.config.torch_dtype == "float32": + internal_model.config.torch_dtype = torch.float32 + elif internal_model.config.torch_dtype == "bfloat16": + internal_model.config.torch_dtype = torch.bfloat16 + elif internal_model.config.torch_dtype == "float16": + internal_model.config.torch_dtype = torch.float16 + pass + pass + internal_model = internal_model.model + pass + if hasattr(internal_model, "config"): + if internal_model.config.torch_dtype == "float32": + internal_model.config.torch_dtype = torch.float32 + elif internal_model.config.torch_dtype == "bfloat16": + internal_model.config.torch_dtype = torch.bfloat16 + elif internal_model.config.torch_dtype == "float16": + internal_model.config.torch_dtype = torch.float16 + pass + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + try: + from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit + except: + raise ImportError("Unsloth: Please install bitsandbytes via `pip install bitsandbytes`") + try: + from peft.tuners.lora import Linear4bit as Peft_Linear4bit + except: + raise ImportError("Unsloth: Please install peft via `pip install peft`") + pass + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! + module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + pass + # Downcast RoPE embedding to correct data type + if downcast_rope and ((name.endswith("rotary_emb") or hasattr(module, "cos_cached"))): + + if hasattr(module, "cos_cached") and \ + (module.cos_cached.dtype != correct_dtype): + + module.cos_cached = module.cos_cached.to(correct_dtype) + module.sin_cached = module.sin_cached.to(correct_dtype) + + elif hasattr(module, "short_cos_cached") and \ + (module.short_cos_cached.dtype != correct_dtype): + + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) + module.short_sin_cached = module.short_sin_cached.to(correct_dtype) + pass + pass + pass + + # Clear deleted GPU items + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return model, tokenizer +pass diff --git a/unsloth_zoo/tokenizer_utils.py b/unsloth_zoo/tokenizer_utils.py index 5fcc4a26d..08ddf1d8d 100644 --- a/unsloth_zoo/tokenizer_utils.py +++ b/unsloth_zoo/tokenizer_utils.py @@ -104,10 +104,16 @@ def add_new_tokens( mean_lm_head = mean_lm_head .to(torch.float32) # Get old lengths - old_input_length = model.get_input_embeddings ().weight.shape[0] - old_output_length = model.get_output_embeddings().weight.shape[0] + old_input_embedding = model.get_input_embeddings ().weight + old_output_embedding = model.get_output_embeddings().weight + old_input_length = old_input_embedding .shape[0] + old_output_length = old_output_embedding.shape[0] old_config_size = model.config.vocab_size + # Check for tied weights as well + is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ + or (model.config.tie_word_embeddings) + # Add tokens! old_length = len(tokenizer) tokenizer.add_tokens(new_tokens) @@ -165,7 +171,26 @@ def add_new_tokens( internal_model = internal_model.model pass internal_model._need_to_train_embeddings = True - + + # Fix up all vocab sizes + current_model = model + while hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model = current_model.model + if hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + pass + + # Must tie lm_head and embed_tokens if they are tied! + # Otherwise error will occur on saving models ie use save_model + if is_tied: model.tie_weights() + + # Clear deleted GPU items + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() return pass @@ -381,9 +406,6 @@ def patch_tokenizer(model, tokenizer): joiner = "\1\0=+=\0\1" number_repetitions = 3 - 1 # Number of reserved tokens needed - if model is not None: - model.config.update({"unsloth_version" : __version__}) - bad_pad_token = False if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!