From f0aca9073edd18a21f37fe112391f3e840474f7c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 21 Oct 2024 01:02:53 -0700 Subject: [PATCH 001/209] Fix TRL --- pyproject.toml | 4 ++-- unsloth/models/_utils.py | 4 ++-- unsloth/tokenizer_utils.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 455f8477e8..a2a9c2c939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -227,7 +227,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1", "xformers<0.0.27", "bitsandbytes>=0.43.3", diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4aa8e1e592..2381f509f3 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1194,8 +1194,8 @@ def patch_gradient_accumulation_fix(Trainer): logger.warning_once( "Unsloth: We fixed a gradient accumulation bug, "\ "but it seems like you don't have the latest transformers version!\n"\ - "Please update transformers via:\n"\ - '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`' + "Please update transformers, TRL and unsloth via:\n"\ + '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`' ) pass pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index ffe9933f47..9ad603cd54 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -884,7 +884,7 @@ def patch_sft_trainer_tokenizer(): check_text = \ "\n"\ - "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ @@ -941,7 +941,8 @@ def patch_sft_trainer_tokenizer(): " from transformers import __version__ as transformers_version\n"\ " from packaging.version import Version\n"\ " if Version(transformers_version) <= Version('4.45.2'):\n"\ - " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\ + " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\ + " '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\ "except:\n"\ " pass\n"\ "\n\n" From f4ae58540578676bcfc0ce59903fa91e27a2b5f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 00:29:30 -0700 Subject: [PATCH 002/209] Update mistral.py --- unsloth/models/mistral.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 15e9efc426..00dcc5cd1d 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -254,8 +254,9 @@ def MistralForCausalLM_fast_forward( shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, + logits = shift_logits, + labels = shift_labels, + n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) pass From 106f213f405c6ef26515c8c1ac286b5d705ea175 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 00:53:38 -0700 Subject: [PATCH 003/209] Patch processing_class --- unsloth/tokenizer_utils.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 9ad603cd54..e11783f7c5 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -831,6 +831,7 @@ def check_nvidia(): PRE_CHECK = check_nvidia() +import inspect from inspect import getsource import trl.trainer.sft_trainer from trl.trainer.sft_trainer import * @@ -869,6 +870,29 @@ def neftune_post_forward_hook(module, input, output): pass +def patch_trl_tokenizer_processing_class(trainer_name): + # New TRL removes tokenizer! + # We return it back! + exec(f"from trl import {trainer_name}") + if str(eval(f"{trainer_name}")__name__).startswith("Unsloth"): return None + exec(f"parameters = inspect.signature({trainer_name}).parameters") + if "tokenizer" in parameters: return None + + args = {key : value.default for key, value in parameters.items()} + args["tokenizer"] = None + new_args = args.copy() + del new_args["tokenizer"] + del new_args["processing_class"] + new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \ + f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class" + args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items()) + args = f"{' '*4}def __init__(\n" + f"{' '*8}self,\n" + args + "):" + args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})" + new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n""" + return new_class +pass + + def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes @@ -982,4 +1006,13 @@ def patch_sft_trainer_tokenizer(): pass pass +# Fix TRL trainers with removed tokenizer args (got replaced with processing_class) +for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): + trainer_text = patch_trl_tokenizer_processing_class(trainer_name) + if trainer_text is None: continue + exec(trainer_text, globals()) + exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) +pass + +# FInally patch TRL tokenizer things patch_sft_trainer_tokenizer() From ef842120d2490ba456fb2a13417cca5406b0e9da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 00:55:47 -0700 Subject: [PATCH 004/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index e11783f7c5..9d69fd12ce 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -874,7 +874,7 @@ def patch_trl_tokenizer_processing_class(trainer_name): # New TRL removes tokenizer! # We return it back! exec(f"from trl import {trainer_name}") - if str(eval(f"{trainer_name}")__name__).startswith("Unsloth"): return None + if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None exec(f"parameters = inspect.signature({trainer_name}).parameters") if "tokenizer" in parameters: return None From 4f7c527ae0b87073610ac45de34428fa887a4663 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 00:57:36 -0700 Subject: [PATCH 005/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 9d69fd12ce..1dbff4f57c 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -873,9 +873,9 @@ def neftune_post_forward_hook(module, input, output): def patch_trl_tokenizer_processing_class(trainer_name): # New TRL removes tokenizer! # We return it back! - exec(f"from trl import {trainer_name}") + exec(f"from trl import {trainer_name}", globals()) if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None - exec(f"parameters = inspect.signature({trainer_name}).parameters") + parameters = eval(f"inspect.signature({trainer_name}).parameters") if "tokenizer" in parameters: return None args = {key : value.default for key, value in parameters.items()} From aa2b20763e5730e30a0c71df319d7674826d7c8e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 01:05:41 -0700 Subject: [PATCH 006/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 1dbff4f57c..3978f62e3b 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1010,6 +1010,7 @@ def patch_sft_trainer_tokenizer(): for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): trainer_text = patch_trl_tokenizer_processing_class(trainer_name) if trainer_text is None: continue + print(trainer_text) exec(trainer_text, globals()) exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) pass From 101389d728881b441dc123b60551579eadf5c3bc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 01:09:20 -0700 Subject: [PATCH 007/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 3978f62e3b..2f5295e017 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -886,7 +886,7 @@ def patch_trl_tokenizer_processing_class(trainer_name): new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \ f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class" args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items()) - args = f"{' '*4}def __init__(\n" + f"{' '*8}self,\n" + args + "):" + args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):" args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})" new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n""" return new_class From c0f0fc987d31d374db03656c87047bdc40ce5e96 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 01:22:13 -0700 Subject: [PATCH 008/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 2f5295e017..6e9c061fad 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -878,7 +878,13 @@ def patch_trl_tokenizer_processing_class(trainer_name): parameters = eval(f"inspect.signature({trainer_name}).parameters") if "tokenizer" in parameters: return None - args = {key : value.default for key, value in parameters.items()} + args = { + key : \ + value.default \ + if type(value.default) is not str else \ + f"'{value.default}'" \ + for key, value in parameters.items() + } args["tokenizer"] = None new_args = args.copy() del new_args["tokenizer"] From b3e00335c1b45f8093599ec5b151b5c6d6546952 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 22 Oct 2024 01:28:38 -0700 Subject: [PATCH 009/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 6e9c061fad..d7af61a090 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -914,6 +914,7 @@ def patch_sft_trainer_tokenizer(): check_text = \ "\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ "test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ From aabb5ff54b4f6baa7922e03e59d9cbaabb7ea4c9 Mon Sep 17 00:00:00 2001 From: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:55:26 +0200 Subject: [PATCH 010/209] Installation guide (#1165) --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 6b583d5fd5..d0e5a97999 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,18 @@ x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') ``` +### Windows Installation + +To run Unsloth directly on Windows: +- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows +- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue: +```python +trainer = SFTTrainer( + dataset_num_proc=1, + ... +) +``` + For **advanced installation instructions** or if you see weird errors during installations: 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton` From 30bf33957dd8822926661ccdb52a1ad908609683 Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Wed, 23 Oct 2024 16:59:02 +0900 Subject: [PATCH 011/209] chore: update chat_templates.py (#1166) orginal -> original --- unsloth/chat_templates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index cab6130dd1..b254202c75 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -678,7 +678,7 @@ {{- end }} {{- if .Tools }} -You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question. +You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original use question. {{- end }} {{- end }}<|eot_id|> {{- range $i, $_ := .Messages }} From 28958397ee376d0c06435c50d36f4e0d6484aca0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 23 Oct 2024 02:58:40 -0700 Subject: [PATCH 012/209] Disable Flex Attention --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 2fba359b77..08426b69e0 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -31,7 +31,7 @@ create_block_mask as _create_block_mask, ) _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options) - HAS_FLEX_ATTENTION = True + HAS_FLEX_ATTENTION = False except: HAS_FLEX_ATTENTION = False pass From 06f5d75b811a70f94f1d97baf0970b450973df9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 23 Oct 2024 03:04:22 -0700 Subject: [PATCH 013/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index d7af61a090..1cad00d44d 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1017,7 +1017,6 @@ def patch_sft_trainer_tokenizer(): for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): trainer_text = patch_trl_tokenizer_processing_class(trainer_name) if trainer_text is None: continue - print(trainer_text) exec(trainer_text, globals()) exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) pass From 28e6eeabd8f841a4440735c387cad5f4a492e879 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 23 Oct 2024 03:14:48 -0700 Subject: [PATCH 014/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 2381f509f3..25024be2ec 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.3" +__version__ = "2024.10.4" __all__ = [ "prepare_model_for_kbit_training", From b821f20b36cb1cf27bb1f6e928dc55b13a55ab15 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 23 Oct 2024 22:13:45 -0700 Subject: [PATCH 015/209] n_items --- unsloth/__init__.py | 10 +++++++--- unsloth/kernels/cross_entropy_loss.py | 1 + unsloth/tokenizer_utils.py | 11 ++++++++--- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index abee9c9e04..458c2696bc 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -62,9 +62,13 @@ try: import torch -except: - raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\ - "We have some installation instructions on our Github page.") +except ModuleNotFoundError: + raise ImportError( + "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"\ + "We have some installation instructions on our Github page." + ) +except Exception as exception: + raise exception pass # Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 1c8f8c8d99..35df257329 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -373,6 +373,7 @@ def fast_cross_entropy_loss( logit_softcapping, logit_scaling, ) + print(n_items) if n_items is None: n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 1cad00d44d..8806f1e743 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -914,8 +914,10 @@ def patch_sft_trainer_tokenizer(): check_text = \ "\n"\ - "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ - "test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n"\ + "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\ + "if 'formatting_func' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `formatting_func` does not exist!')\n"\ + "if 'dataset_text_field' not in locals(): raise RuntimeError('Unsloth: Please file a bug report - `dataset_text_field` does not exist!')\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ @@ -1017,7 +1019,10 @@ def patch_sft_trainer_tokenizer(): for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): trainer_text = patch_trl_tokenizer_processing_class(trainer_name) if trainer_text is None: continue - exec(trainer_text, globals()) + try: + exec(trainer_text, globals()) + except: + raise RuntimeError(f"Unsloth: Please file a bug report! Error patching {trainer_name}") exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) pass From e56136636540097f63bb11fa540558001d30b880 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 23 Oct 2024 22:18:24 -0700 Subject: [PATCH 016/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 35df257329..4895d27310 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -373,9 +373,9 @@ def fast_cross_entropy_loss( logit_softcapping, logit_scaling, ) - print(n_items) if n_items is None: n_items = torch.count_nonzero(labels != -100) + print(n_items) return loss.sum() / n_items pass From 4ff247ab18600fb3bc474ddd89dfdce76b50c287 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 00:17:26 -0700 Subject: [PATCH 017/209] Fix DPO, ORPO --- unsloth/kernels/cross_entropy_loss.py | 1 - unsloth/models/_utils.py | 33 +++++++++++++++++++++++++-- unsloth/save.py | 2 +- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4895d27310..1c8f8c8d99 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -375,7 +375,6 @@ def fast_cross_entropy_loss( ) if n_items is None: n_items = torch.count_nonzero(labels != -100) - print(n_items) return loss.sum() / n_items pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 25024be2ec..6611b4f638 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1172,10 +1172,10 @@ def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs): def patch_gradient_accumulation_fix(Trainer): # Fixes gradient accumulation + import inspect if hasattr(Trainer, "get_batch_samples"): - from inspect import getsource if \ - not getsource(Trainer.get_batch_samples).strip()\ + not inspect.getsource(Trainer.get_batch_samples).strip()\ .endswith("return batch_samples, num_items_in_batch"): raise NotImplementedError("Unsloth: Please make a Github issue immediately!!") @@ -1198,4 +1198,33 @@ def patch_gradient_accumulation_fix(Trainer): '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`' ) pass + + # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps + if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return + + function = inspect.getsource(Trainer.training_step) + where = function.find("def") + function = function.split("\n") + function = "\n".join(x[where:] for x in function) + + # Import all variables that need importing + import transformers.trainer + items_in_trainer = dir(transformers.trainer) + good_items = [] + for item in items_in_trainer: + # TODO: Support Deepspeed + if item.startswith(("deepspeed", "xm", "met", "smp")): continue + if item in function: good_items.append(item) + pass + exec("from transformers.trainer import (" + ", ".join(x for x in good_items) + ")", globals()) + + # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already + # summed it up and did the division before hand, we have to negate it. + function = function.replace( + "loss *= self.args.gradient_accumulation_steps", + "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps", + ) + function = function.replace("def training_step", "def _unsloth_training_step", 1) + exec(function, globals()) + Trainer.training_step = _unsloth_training_step pass diff --git a/unsloth/save.py b/unsloth/save.py index ab30e0fea5..ccda79aeee 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -145,7 +145,7 @@ def _free_cached_model(model): def _merge_lora(layer, name): - bias = None + bias = getattr(layer, "bias", None) if isinstance(layer, (Bnb_Linear4bit, Peft_Linear4bit, Peft_Linear)): # Is LoRA so we need to merge! W, quant_state, A, B, s, bias = get_lora_parameters_bias(layer) From 1c063b4c98a9f63e47ba86d887e996dc8dc12e2a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 00:25:28 -0700 Subject: [PATCH 018/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3724cdeab2..bf5216b228 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.5" +__version__ = "2024.10.6" __all__ = [ "prepare_model_for_kbit_training", From f195ee1e6567dcc14620961b353d2c42226a4a54 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 01:11:20 -0700 Subject: [PATCH 019/209] Update _utils.py --- unsloth/models/_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index bf5216b228..873a2723c2 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -162,6 +162,20 @@ def patch_mistral_nemo_config(config): pass # ============================================= +# ============================================= +# Weird Databricks errors +from transformers.utils import is_openai_available +if is_openai_available(): + try: + from openai import OpenAI + except: + print("Unsloth: OpenAI failed to import - ignoring for now.") + import transformers.utils + def _is_openai_available(): return False + transformers.utils.is_openai_available = _is_openai_available + pass +pass + # ============================================= # Get Flash Attention v2 if Ampere (RTX 30xx, A100) import bitsandbytes as bnb From faf27477aab01ed879b8d5d8d75fc6de2c5c05e5 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Thu, 24 Oct 2024 23:10:52 +0400 Subject: [PATCH 020/209] fix/transformers-unpack (#1180) * Fix DPO, ORPO (#1177) * Fix TRL * Update mistral.py * Patch processing_class * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Update tokenizer_utils.py * Installation guide (#1165) * chore: update chat_templates.py (#1166) orginal -> original * Disable Flex Attention * Update tokenizer_utils.py * Update _utils.py * n_items * Update cross_entropy_loss.py * Fix DPO, ORPO * Update _utils.py --------- Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine * Add warning for missing Unpack and KwargsForCausalLM in older Transformers versions --------- Co-authored-by: Daniel Han Co-authored-by: timothelaborie <97834767+timothelaborie@users.noreply.github.com> Co-authored-by: Ikko Eltociear Ashimine --- unsloth/kernels/cross_entropy_loss.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 1c8f8c8d99..57d07af493 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -388,6 +388,13 @@ def fast_cross_entropy_loss( List, Tuple, ) + +try: + from transformers.models.llama.modeling_llama import Unpack, KwargsForCausalLM +except ImportError: + logger.warning("Unsloth: Could not find Unpack, KwargsForCausalLM in LlamaForCausalLM. " + "This is expected if you are using an older version of Transformers (<4.46.0). ") + import inspect, re function = inspect.getsource(LlamaForCausalLM.forward) function = function.split("\n") From 5961c34a71871aaf18335f0e493c38a02e48d458 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 12:11:38 -0700 Subject: [PATCH 021/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 57d07af493..f2377d55cc 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -389,11 +389,12 @@ def fast_cross_entropy_loss( Tuple, ) +# Transformers 4.47 need Unpack, KwargsForCausalLM try: from transformers.models.llama.modeling_llama import Unpack, KwargsForCausalLM -except ImportError: - logger.warning("Unsloth: Could not find Unpack, KwargsForCausalLM in LlamaForCausalLM. " - "This is expected if you are using an older version of Transformers (<4.46.0). ") +except: + pass +pass import inspect, re function = inspect.getsource(LlamaForCausalLM.forward) From 7308bb82998322f2cae91ec217efd4e82fff9086 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 12:14:14 -0700 Subject: [PATCH 022/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 873a2723c2..68e294f157 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.6" +__version__ = "2024.10.7" __all__ = [ "prepare_model_for_kbit_training", From 0096e5b07ff1642163168bda6e9c41338edb470c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 24 Oct 2024 12:17:09 -0700 Subject: [PATCH 023/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 68e294f157..873a2723c2 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.7" +__version__ = "2024.10.6" __all__ = [ "prepare_model_for_kbit_training", From 67760559e4b4a77add264c0e2814f6a13beea964 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 25 Oct 2024 13:58:12 +0530 Subject: [PATCH 024/209] donot upcast lm_head and embeddings to float32 (#1186) --- unsloth/models/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 5f20f51209..9c5499dc75 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1955,7 +1955,7 @@ def get_peft_model( print("Unsloth: Casting embed_tokens to float32") model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! @@ -1968,7 +1968,7 @@ def get_peft_model( print("Unsloth: Casting lm_head to float32") model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! @@ -2206,7 +2206,7 @@ def get_peft_model( print("Unsloth: Casting embed_tokens to float32") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) model.model.model.embed_tokens.modules_to_save.default\ - .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass @@ -2214,7 +2214,7 @@ def get_peft_model( print("Unsloth: Casting lm_head to float32") assert(hasattr(model.model.lm_head, "modules_to_save")) model.model.lm_head.modules_to_save.default\ - .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) pass From 625209e11febd6d91a5edf8dfdfd04906a013c9f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Sat, 26 Oct 2024 00:47:54 +0530 Subject: [PATCH 025/209] Cleanup upcast logs (#1188) --- unsloth/models/llama.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9c5499dc75..cb0ce2feaf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1952,7 +1952,7 @@ def get_peft_model( # Offload! # [TODO] First offload lm_head and embed_tokens to CPU (should be disk!!) if "embed_tokens" in new_target_modules: - print("Unsloth: Casting embed_tokens to float32") + print("Unsloth: Training embed_tokens in mixed precision to save VRAM") model.model.model.embed_tokens.modules_to_save.default\ .to(device = "cuda:0", non_blocking = True) @@ -1965,7 +1965,7 @@ def get_peft_model( pass if "lm_head" in new_target_modules: - print("Unsloth: Casting lm_head to float32") + print("Unsloth: Training lm_head in mixed precision to save VRAM") model.model.lm_head.modules_to_save.default\ .to(device = "cuda:0", non_blocking = True) @@ -2203,7 +2203,7 @@ def get_peft_model( # Now patch lm_head and embed_tokens if train_embed_tokens: - print("Unsloth: Casting embed_tokens to float32") + print("Unsloth: Training embed_tokens in mixed precision to save VRAM") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) model.model.model.embed_tokens.modules_to_save.default\ .to(device = "cuda:0", non_blocking = True) @@ -2211,7 +2211,7 @@ def get_peft_model( pass if train_lm_head: - print("Unsloth: Casting lm_head to float32") + print("Unsloth: Training lm_head in mixed precision to save VRAM") assert(hasattr(model.model.lm_head, "modules_to_save")) model.model.lm_head.modules_to_save.default\ .to(device = "cuda:0", non_blocking = True) From 2bc189f490a7e1c1f5431c326d0e4bb14858a2e4 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Sat, 26 Oct 2024 02:44:10 +0400 Subject: [PATCH 026/209] Fix/phi-longrope (#1193) * Enhance rotary embedding handling in LlamaAttention and LongRopeRotaryEmbedding * Typo * Improve rotary embedding handling in LlamaAttention to prevent errors with short KV cache * Update llama.py * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/llama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cb0ce2feaf..c98feeca1e 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -193,6 +193,10 @@ def LlamaAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) + + # Need to do it prior 2 steps before hitting full on short KV cache + # or else error + self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) cos, sin = self.rotary_emb.get_cached(kv_seq_len) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) @@ -1122,7 +1126,7 @@ def get_cached(self, seq_len = None): def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 - self.current_rope_size = math.ceil(seq_len / 8192) * 8192 + self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) pass pass @@ -1248,7 +1252,7 @@ def get_cached(self, seq_len = None): def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 - self.current_rope_size = math.ceil(seq_len / 8192) * 8192 + self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) pass pass @@ -1363,7 +1367,7 @@ def get_cached(self, seq_len = None): def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 - self.current_rope_size = math.ceil(seq_len / 8192) * 8192 + self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype) pass pass From 6f28d160b217dc6145b64f5168cfeeb771b2da38 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 26 Oct 2024 01:20:37 -0700 Subject: [PATCH 027/209] Update transformers --- unsloth/models/_utils.py | 2 +- unsloth/tokenizer_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 68e294f157..0acc8cd350 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1209,7 +1209,7 @@ def patch_gradient_accumulation_fix(Trainer): "Unsloth: We fixed a gradient accumulation bug, "\ "but it seems like you don't have the latest transformers version!\n"\ "Please update transformers, TRL and unsloth via:\n"\ - '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`' + '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`' ) pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 8806f1e743..c05485f902 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -975,7 +975,7 @@ def patch_sft_trainer_tokenizer(): " from packaging.version import Version\n"\ " if Version(transformers_version) <= Version('4.45.2'):\n"\ " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\ - " '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\ + " '`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`')\n"\ "except:\n"\ " pass\n"\ "\n\n" From 7083a1d455890ed276af9a7a4aee2b2bd655a2a8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 17:32:26 -0700 Subject: [PATCH 028/209] Unk token issues --- unsloth/models/_utils.py | 48 ++++++++++++++++++++++++++++++++++++++-- unsloth/models/llama.py | 5 +++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 51c63fc7dd..c0a15592e5 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -487,6 +487,50 @@ def patch_tokenizer(model, tokenizer): bad_pad_token = False pass + # Check if unknown token is broken + fixed_unk_token = False + + if hasattr(tokenizer, "unk_token") and tokenizer.unk_token is not None: + + eos_token = getattr(tokenizer, "eos_token", None) + bos_token = getattr(tokenizer, "bos_token", None) + + old_unk_token = tokenizer.unk_token + if old_unk_token == eos_token or old_unk_token == bos_token: + has_broken_unk = True + # Use the unicode replacement characters + possible_replacements = [ + "\uFFFD", # Original replacement char + "\uFFFC", # Another option + "\u2753", # Red Question mark emoji + "\u2754", # White Question mark emoji + "\u00BF", # Inverted question mark + ] + for replacement_char in possible_replacements: + char = tokenizer(replacement_char, add_special_tokens = False) + if len(char) == 1: + # Get actual token representation + try: char = tokenizer.convert_ids_to_tokens(char[0]) + except: continue + tokenizer.unk_token = char + fixed_unk_token = True + break + pass + pass + + if not fixed_unk_token: # Still broken! + raise RuntimeError( + f"Unsloth: Tried fixing the unk_token = {old_unk_token}, but couldn't!" + ) + pass + + logger.warning_once( + f"Unsloth: unk_token = {old_unk_token} is the same as the EOS or BOS tokens.\n"\ + f"We fixed it by changing it to {tokenizer.unk_token}." + ) + pass + pass + if bad_pad_token: # Find a better pad token added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()] @@ -534,8 +578,8 @@ def patch_tokenizer(model, tokenizer): pass possible_pad_token = final_pad_token - # Try unk_token - if possible_pad_token is None and hasattr(tokenizer, "unk_token"): + # Try unk_token if it wasn't fixed + if possible_pad_token is None and not fixed_unk_token and hasattr(tokenizer, "unk_token"): possible_pad_token = tokenizer.unk_token pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index cf05d432c7..9c9ea53752 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1910,6 +1910,11 @@ def get_peft_model( ): transformers_set_seed(random_state) + if type(r) is not int: + raise TypeError(f"Unsloth: Rank of {str(r)} must be an integer.") + if r <= 0: + raise TypeError(f"Unsloth: Rank of {str(r)} must be larger than 0.") + if isinstance(model, PeftModelForCausalLM): # Check if exactly the same and then pass through! assert(hasattr(model, "peft_config")) From 3acc5afad3e7e93332c6252781a4ac85ea0267a5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 17:34:33 -0700 Subject: [PATCH 029/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c0a15592e5..d92a938eee 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -507,7 +507,7 @@ def patch_tokenizer(model, tokenizer): "\u00BF", # Inverted question mark ] for replacement_char in possible_replacements: - char = tokenizer(replacement_char, add_special_tokens = False) + char = tokenizer(replacement_char, add_special_tokens = False).input_ids if len(char) == 1: # Get actual token representation try: char = tokenizer.convert_ids_to_tokens(char[0]) From 1c044da660810c422b32041cbfbd1519ff2db6e9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 19:06:57 -0700 Subject: [PATCH 030/209] Fix pad token --- unsloth/models/_utils.py | 27 ++++++++++++++-- unsloth/models/gemma.py | 7 ++--- unsloth/models/gemma2.py | 7 ++--- unsloth/models/llama.py | 68 ++++++++++++++++++++++++++++++---------- 4 files changed, 80 insertions(+), 29 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d92a938eee..46a5f45cfa 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -525,7 +525,7 @@ def patch_tokenizer(model, tokenizer): pass logger.warning_once( - f"Unsloth: unk_token = {old_unk_token} is the same as the EOS or BOS tokens.\n"\ + f"Unsloth: unk_token = {old_unk_token} is the same as the EOS or BOS tokens. "\ f"We fixed it by changing it to {tokenizer.unk_token}." ) pass @@ -610,13 +610,34 @@ def patch_tokenizer(model, tokenizer): tokenizer.add_special_tokens({"pad_token" : possible_pad_token}) tokenizer.pad_token = possible_pad_token if model is not None: - model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + + # Edit all config with new pad token + current_model = model + while hasattr(model, "model") and hasattr(model, "config"): + current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + current_model = current_model.model + if hasattr(model, "model") and hasattr(model, "config"): + current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + pass + + # Generation edit pad token if getattr(model, "generation_config") is not None: model.generation_config.update(pad_token_id = tokenizer.pad_token_id) else: if model is not None: + if model.config.pad_token_id is None: - model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + + # Edit all config with new pad token + current_model = model + while hasattr(model, "model") and hasattr(model, "config"): + current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + current_model = model + if hasattr(model, "model") and hasattr(model, "config"): + current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) + pass + + # Generation edit pad token if getattr(model, "generation_config") is not None: model.generation_config.update(pad_token_id = tokenizer.pad_token_id) pass diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 45f14c1131..1ec116b2ea 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -339,10 +339,7 @@ def pre_patch(): @staticmethod - def post_patch(model): - # Patch model for Gemma - layers = model.model.layers - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) @@ -425,6 +422,6 @@ def post_patch(model): for _ in range(3): gc.collect() torch.cuda.empty_cache() - return model + return model, tokenizer pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index bf40ea8a27..54d8f628cb 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -490,10 +490,7 @@ def pre_patch(): @staticmethod - def post_patch(model): - # Patch model for Gemma - layers = model.model.layers - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) @@ -576,6 +573,6 @@ def post_patch(model): for _ in range(3): gc.collect() torch.cuda.empty_cache() - return model + return model, tokenizer pass pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9c9ea53752..044ea6e244 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1621,7 +1621,7 @@ def from_pretrained( ) model, tokenizer = patch_tokenizer(model, tokenizer) - model = model_patcher.post_patch(model) + model, tokenizer = model_patcher.post_patch(model, tokenizer) # Patch up QKV / O and MLP for idx, layer in enumerate(model.model.layers): @@ -1827,27 +1827,63 @@ def from_pretrained( @staticmethod - def post_patch(model): - # Patch model - layers = model.model.layers - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? - # Workaround randomnly fixes it for torch versions < 2. - model.set_input_embeddings(torch.nn.Embedding.from_pretrained(model.get_input_embeddings().weight)) + try: old_input_embedding = model.get_input_embeddings ().weight + except: return model, tokenizer + + # Maybe not all models have a lm_head? + try: old_output_embedding = model.get_output_embeddings().weight + except: old_output_embedding = torch.zeros(0) + + # Check for tied weights as well + is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() + + # Check pad token's id -> we need to expand the embedding + if len(tokenizer) > old_input_embedding.shape[0]: + # Workaround randomnly fixes it for torch versions < 2. + requires_grad = old_input_embedding.requires_grad + old_input_embedding.requires_grad_(False) + old_input_embedding.resize_(len(tokenizer), old_input_embedding.shape[1]) + old_input_embedding.requires_grad_(requires_grad) + + # Fix up all vocab sizes + current_model = model + while hasattr(model, "model") and hasattr(model, "config"): + if hasattr(model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + current_model = current_model.model + if hasattr(model, "model") and hasattr(model, "config"): + if hasattr(model.config, "vocab_size"): + current_model.config.update({"vocab_size" : len(tokenizer)}) + pass + pass + + model.set_input_embeddings( + torch.nn.Embedding.from_pretrained( + old_input_embedding, + padding_idx = getattr(model.config, "pad_token_id", None), + ) + ) model.config.update({"unsloth_version" : __version__}) # We also do this for the lm_head - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.get_output_embeddings().weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head + if old_output_embedding.numel() != 0: + requires_grad = old_output_embedding.requires_grad + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = old_output_embedding if not is_tied else old_input_embedding + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + lm_head.weight.requires_grad_(requires_grad) + model.lm_head = lm_head + correct_dtype = lm_head.weight.dtype + else: + correct_dtype = old_input_embedding.dtype + pass # Also patch all dtypes - BnB seems to not allocate the correct type? # BnB default dtype seems to be float16! - correct_dtype = lm_head.weight.dtype - for name, module in model.named_modules(): if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): weight = module.weight @@ -1883,7 +1919,7 @@ def post_patch(model): for _ in range(3): gc.collect() torch.cuda.empty_cache() - return model + return model, tokenizer pass From 5286f1972560ef1b106e07e22132cd771143245f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 19:08:14 -0700 Subject: [PATCH 031/209] Update llama.py --- unsloth/models/llama.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 044ea6e244..a53f52fe27 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1869,14 +1869,19 @@ def post_patch(model, tokenizer): # We also do this for the lm_head if old_output_embedding.numel() != 0: + requires_grad = old_output_embedding.requires_grad lm_head = torch.nn.Linear(1, 1, bias = None) del lm_head.weight + lm_head.weight = old_output_embedding if not is_tied else old_input_embedding lm_head.in_features = lm_head.weight.shape[1] lm_head.out_features = lm_head.weight.shape[0] + lm_head.weight.requires_grad_(requires_grad) - model.lm_head = lm_head + model.set_output_embeddings(lm_head) + if hasattr(model, "lm_head"): model.lm_head = lm_head + correct_dtype = lm_head.weight.dtype else: correct_dtype = old_input_embedding.dtype From 02437a839105c5556f672b122d033f4de22fe095 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 19:09:27 -0700 Subject: [PATCH 032/209] Typo --- unsloth/models/_utils.py | 8 ++++---- unsloth/models/llama.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 46a5f45cfa..b2f4b5c66e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -613,10 +613,10 @@ def patch_tokenizer(model, tokenizer): # Edit all config with new pad token current_model = model - while hasattr(model, "model") and hasattr(model, "config"): + while hasattr(current_model, "model") and hasattr(current_model, "config"): current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) current_model = current_model.model - if hasattr(model, "model") and hasattr(model, "config"): + if hasattr(current_model, "model") and hasattr(current_model, "config"): current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) pass @@ -630,10 +630,10 @@ def patch_tokenizer(model, tokenizer): # Edit all config with new pad token current_model = model - while hasattr(model, "model") and hasattr(model, "config"): + while hasattr(current_model, "model") and hasattr(current_model, "config"): current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) current_model = model - if hasattr(model, "model") and hasattr(model, "config"): + if hasattr(current_model, "model") and hasattr(current_model, "config"): current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index a53f52fe27..65e2d773e8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1849,12 +1849,12 @@ def post_patch(model, tokenizer): # Fix up all vocab sizes current_model = model - while hasattr(model, "model") and hasattr(model, "config"): - if hasattr(model.config, "vocab_size"): + while hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): current_model.config.update({"vocab_size" : len(tokenizer)}) current_model = current_model.model - if hasattr(model, "model") and hasattr(model, "config"): - if hasattr(model.config, "vocab_size"): + if hasattr(current_model, "model") and hasattr(current_model, "config"): + if hasattr(current_model.config, "vocab_size"): current_model.config.update({"vocab_size" : len(tokenizer)}) pass pass From 9d07be077b3355b55dcf93098d0afe2591e67750 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 22:10:59 -0700 Subject: [PATCH 033/209] ignored labels --- unsloth/models/gemma.py | 13 ++++++++++++- unsloth/models/gemma2.py | 13 ++++++++++++- unsloth/models/llama.py | 27 +++++++++++++-------------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 1ec116b2ea..095c8cdd6b 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -339,7 +339,18 @@ def pre_patch(): @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, max_seq_length): + # Add max_seq_length to all modules + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") + internal_model = model + while hasattr(internal_model, "model"): + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + internal_model = internal_model.model + pass + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 54d8f628cb..231d8f2661 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -490,7 +490,18 @@ def pre_patch(): @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, max_seq_length): + # Add max_seq_length to all modules + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") + internal_model = model + while hasattr(internal_model, "model"): + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + internal_model = internal_model.model + pass + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 65e2d773e8..4712d9ca05 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1621,7 +1621,7 @@ def from_pretrained( ) model, tokenizer = patch_tokenizer(model, tokenizer) - model, tokenizer = model_patcher.post_patch(model, tokenizer) + model, tokenizer = model_patcher.post_patch(model, tokenizer, max_position_embeddings) # Patch up QKV / O and MLP for idx, layer in enumerate(model.model.layers): @@ -1827,7 +1827,18 @@ def from_pretrained( @staticmethod - def post_patch(model, tokenizer): + def post_patch(model, tokenizer, max_seq_length): + # Add max_seq_length to all modules + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") + internal_model = model + while hasattr(internal_model, "model"): + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + internal_model = internal_model.model + pass + internal_model.max_seq_length = max_seq_length + internal_model.extra_ignored_labels = extra_ignored_labels + # Torch.compile fails on embedding matrix?? try: old_input_embedding = model.get_input_embeddings ().weight except: return model, tokenizer @@ -2459,18 +2470,6 @@ def patch_peft_model( ) patch_saving_functions(model) - # Patch cross entropy loss labels - # Fixes https://github.com/unslothai/unsloth/issues/10 - max_seq_length = model.max_seq_length - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") - model.model.extra_ignored_labels = extra_ignored_labels - internal_model = model - while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_seq_length - internal_model = internal_model.model - pass - internal_model.max_seq_length = max_seq_length - # Patch tokenizer to pad to the right internal_model = model while hasattr(internal_model, "model"): From a8b37a320d3ed72fceff9c08dd1e534eaad703fa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 27 Oct 2024 22:18:05 -0700 Subject: [PATCH 034/209] Revert "ignored labels" This reverts commit 9d07be077b3355b55dcf93098d0afe2591e67750. --- unsloth/models/gemma.py | 13 +------------ unsloth/models/gemma2.py | 13 +------------ unsloth/models/llama.py | 27 ++++++++++++++------------- 3 files changed, 16 insertions(+), 37 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 095c8cdd6b..1ec116b2ea 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -339,18 +339,7 @@ def pre_patch(): @staticmethod - def post_patch(model, tokenizer, max_seq_length): - # Add max_seq_length to all modules - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") - internal_model = model - while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - internal_model = internal_model.model - pass - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 231d8f2661..54d8f628cb 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -490,18 +490,7 @@ def pre_patch(): @staticmethod - def post_patch(model, tokenizer, max_seq_length): - # Add max_seq_length to all modules - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") - internal_model = model - while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - internal_model = internal_model.model - pass - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? # Workaround randomnly fixes it for torch versions < 2.2 model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4712d9ca05..65e2d773e8 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1621,7 +1621,7 @@ def from_pretrained( ) model, tokenizer = patch_tokenizer(model, tokenizer) - model, tokenizer = model_patcher.post_patch(model, tokenizer, max_position_embeddings) + model, tokenizer = model_patcher.post_patch(model, tokenizer) # Patch up QKV / O and MLP for idx, layer in enumerate(model.model.layers): @@ -1827,18 +1827,7 @@ def from_pretrained( @staticmethod - def post_patch(model, tokenizer, max_seq_length): - # Add max_seq_length to all modules - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") - internal_model = model - while hasattr(internal_model, "model"): - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - internal_model = internal_model.model - pass - internal_model.max_seq_length = max_seq_length - internal_model.extra_ignored_labels = extra_ignored_labels - + def post_patch(model, tokenizer): # Torch.compile fails on embedding matrix?? try: old_input_embedding = model.get_input_embeddings ().weight except: return model, tokenizer @@ -2470,6 +2459,18 @@ def patch_peft_model( ) patch_saving_functions(model) + # Patch cross entropy loss labels + # Fixes https://github.com/unslothai/unsloth/issues/10 + max_seq_length = model.max_seq_length + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") + model.model.extra_ignored_labels = extra_ignored_labels + internal_model = model + while hasattr(internal_model, "model"): + internal_model.max_seq_length = max_seq_length + internal_model = internal_model.model + pass + internal_model.max_seq_length = max_seq_length + # Patch tokenizer to pad to the right internal_model = model while hasattr(internal_model, "model"): From 2dfdba3493b8be24b054608a847a85e99cc2f253 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 01:10:23 -0700 Subject: [PATCH 035/209] More patching --- unsloth/kernels/__init__.py | 2 ++ unsloth/kernels/cross_entropy_loss.py | 49 +++++++++++++++++++++++++++ unsloth/models/_utils.py | 45 +++++++++++++++++++++--- 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 3e55332c80..6357ddaf87 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -16,6 +16,8 @@ fast_cross_entropy_loss, patch_llama_for_causal_lm, unpatch_llama_for_causal_lm, + patch_transformers_losses, + patch_loss_function, ) from .rms_layernorm import ( fast_rms_layernorm, diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f2377d55cc..7ec1c258a9 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -470,3 +470,52 @@ def unpatch_llama_for_causal_lm(): transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM return pass + + +@torch._disable_dynamo +def UnslothForCausalLMLoss( + logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs +): + shift_logits = logits + shift_labels = torch.empty_like(labels) + shift_labels[..., :-1] = labels[..., 1:] + shift_labels[..., -1] = -100 + loss = fast_cross_entropy_loss( + logits = shift_logits, + labels = shift_labels, + n_items = num_items_in_batch, + ) + return loss +pass + + +def patch_transformers_losses(): + import re + try: + import transformers.loss.loss_utils + except: + logger.warning_once("Unsloth: Cannot patch loss functions - update transformers for faster modules!") + + import transformers.modeling_utils + LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING + LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss + + # Remove @property and @lru_cache + if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): + transformers.modeling_utils.PreTrainedModel.loss_function = \ + transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ + pass +pass + + +def patch_loss_function(model): + try: + # model.loss_function starts as a dict to a loss fx + # We invoke it to save it + model.loss_function = model.loss_function() + except: + # Failed means we already invoked it, and we need args to the loss fx + pass + pass + return model +pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b2f4b5c66e..bd40fbd2eb 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -44,6 +44,8 @@ "patch_gradient_checkpointing", "unpatch_gradient_checkpointing", "patch_gradient_accumulation_fix", + "patch_compiling_bitsandbytes", + "patch_regional_compilation", ] import torch @@ -683,8 +685,19 @@ def patch_tokenizer(model, tokenizer): ) pass pass + +# Also disable compiling on bitsandbytes +def patch_compiling_bitsandbytes(): + import peft.tuners.lora.bnb + peft.tuners.lora.bnb.Linear4bit.forward = \ + torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) + peft.tuners.lora.bnb.Linear8bit.forward = \ + torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bit.forward) + return +pass # ============================================= + import psutil def _get_statistics(statistics = None, force_download = True): # We log some basic stats about which environment is being used. @@ -896,15 +909,39 @@ def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) pass - import torch.utils -old_checkpoint = torch.utils.checkpoint def patch_gradient_checkpointing(): - torch.utils.checkpoint = unsloth_offloaded_gradient_checkpoint + if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return + torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint + torch.utils.checkpoint.checkpoint = unsloth_offloaded_gradient_checkpoint pass def unpatch_gradient_checkpointing(): - torch.utils.checkpoint = old_checkpoint + if hasattr(torch.utils.checkpoint, "_old_checkpoint"): + torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint + del torch.utils.checkpoint._old_checkpoint + pass +pass + + +# ============================================= +# Regional torch 2.5 Recompilation - weirdly very slow?? +def patch_regional_compilation(): + if torch.nn.ModuleList.__name__ == "UnslothModuleList": return + # Only works for torch 2.5 + if Version(torch.__version__) < Version("2.5.0"): return + + old_module_list = torch.nn.ModuleList + + def UnslothModuleList(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: + args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] + return old_module_list(*args, **kwargs) + pass + UnslothModuleList.__doc__ = old_module_list.__doc__ + + torch.nn.ModuleList = UnslothModuleList + return pass From 5541ab48fe435612b1a14078f3181faa57db6e5e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 01:12:33 -0700 Subject: [PATCH 036/209] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index bd40fbd2eb..692f48488c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -691,8 +691,8 @@ def patch_compiling_bitsandbytes(): import peft.tuners.lora.bnb peft.tuners.lora.bnb.Linear4bit.forward = \ torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) - peft.tuners.lora.bnb.Linear8bit.forward = \ - torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bit.forward) + peft.tuners.lora.bnb.Linear8bitLt.forward = \ + torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) return pass # ============================================= From c6e9af2e5b69abc8cb332abef3eb101b0d33c63e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 10:41:31 -0700 Subject: [PATCH 037/209] Update _utils.py --- unsloth/models/_utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 692f48488c..97b2ea7b7b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -688,11 +688,15 @@ def patch_tokenizer(model, tokenizer): # Also disable compiling on bitsandbytes def patch_compiling_bitsandbytes(): - import peft.tuners.lora.bnb - peft.tuners.lora.bnb.Linear4bit.forward = \ - torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) - peft.tuners.lora.bnb.Linear8bitLt.forward = \ - torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) + # import peft.tuners.lora.bnb + # peft.tuners.lora.bnb.Linear4bit.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) + # peft.tuners.lora.bnb.Linear8bitLt.forward = \ + # torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) + # return + import bitsandbytes.nn.modules + bitsandbytes.nn.modules.Linear4bit.forward = \ + torch._disable_dynamo(bitsandbytes.nn.modules.Linear4bit.forward) return pass # ============================================= From cac56d112b08befe1355deabbf6c856230eb9d5d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 14:30:06 -0700 Subject: [PATCH 038/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 7ec1c258a9..4ff4ec152a 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -349,7 +349,7 @@ def backward(ctx, dlosses): pass -@torch._disable_dynamo +# @torch._disable_dynamo def fast_cross_entropy_loss( logits, labels, @@ -472,7 +472,7 @@ def unpatch_llama_for_causal_lm(): pass -@torch._disable_dynamo +# @torch._disable_dynamo def UnslothForCausalLMLoss( logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs ): @@ -495,7 +495,9 @@ def patch_transformers_losses(): import transformers.loss.loss_utils except: logger.warning_once("Unsloth: Cannot patch loss functions - update transformers for faster modules!") - + return + pass + import transformers.modeling_utils LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss From 5ee1189657fdc50b79c79a5659fb07e73f10de59 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 14:47:11 -0700 Subject: [PATCH 039/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4ff4ec152a..2db59bbae7 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,7 +25,8 @@ }) @triton.jit def _cross_entropy_forward( - logits_ptr, logits_row_stride, + logits_ptr, + logits_row_stride : tl.constexpr(tl.int64), loss_ptr, logsumexp_ptr, labels_ptr, @@ -57,7 +58,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * logits_row_stride loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -97,7 +98,8 @@ def _cross_entropy_forward( }) @triton.jit def _chunked_cross_entropy_forward( - logits_ptr, logits_row_stride, + logits_ptr, + logits_row_stride : tl.constexpr(tl.int64), loss_ptr, logsumexp_ptr, labels_ptr, @@ -135,7 +137,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * logits_row_stride loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -179,7 +181,8 @@ def _chunked_cross_entropy_forward( }) @triton.jit def _cross_entropy_backward( - logits_ptr, logits_row_stride, + logits_ptr, + logits_row_stride : tl.constexpr(tl.int64), dloss_ptr, dloss_row_stride, logsumexp_ptr, labels_ptr, @@ -208,7 +211,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * logits_row_stride dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE @@ -497,7 +500,7 @@ def patch_transformers_losses(): logger.warning_once("Unsloth: Cannot patch loss functions - update transformers for faster modules!") return pass - + import transformers.modeling_utils LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss From 85a5f6098a1bac8af5a3482d2a1569069111c84d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 28 Oct 2024 15:01:04 -0700 Subject: [PATCH 040/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 2db59bbae7..debd037b64 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -269,7 +269,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) - losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + losses = torch.empty(n_rows, dtype = torch.float32, device = logits.device) DO_SOFTCAPPING = (logit_softcapping != 0) DO_LOGIT_SCALING = (logit_scaling != 0) @@ -277,7 +277,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) - logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = logits.device) _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), @@ -294,7 +294,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): ) else: # For large vocabs > 65336 like Gemma 256K - logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0") + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = logits.device) _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( logits, logits.stride(0), From 20e38eda6b6ec2a97e8cdb62c84b6747179c659c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 00:43:03 -0700 Subject: [PATCH 041/209] Feat/all tmp (#1219) * Update save.py Check whether path is in /tmp dir for Kaggle environment * Update save.py Move temporary_location to /tmp in Kaggle * Enhance Kaggle environment support in save and tokenizer utilities --------- Co-authored-by: dendarrion <37800703+dendarrion@users.noreply.github.com> Co-authored-by: Erland366 --- unsloth/save.py | 66 ++++++++++++++++++++++++-------------- unsloth/tokenizer_utils.py | 7 +++- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/unsloth/save.py b/unsloth/save.py index ccda79aeee..b4c6b499cf 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -49,6 +49,7 @@ keynames = "\n" + "\n".join(os.environ.keys()) IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames +KAGGLE_TMP = "/tmp" del keynames # Weights @@ -447,13 +448,20 @@ def unsloth_save_model( if push_to_hub and "/" in save_directory: # +1 solves absolute path issues - username = save_directory[:save_directory.find("/")] - new_save_directory = save_directory[save_directory.find("/")+1:] - - logger.warning_once( - f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"\ - f"We shall truncate {save_directory} to {new_save_directory}" - ) + new_save_directory = save_directory + username = new_save_directory[:new_save_directory.find("/")] + new_save_directory = new_save_directory[new_save_directory.find("/")+1:] + if IS_KAGGLE_ENVIRONMENT: + new_save_directory = os.path.join(KAGGLE_TMP, new_save_directory[new_save_directory.find("/")+1:]) + logger.warning_once( + "Unsloth: You are pushing to hub in Kaggle environment.\n"\ + f"To save memory, we shall move {save_directory} to {new_save_directory}" + ) + else: + logger.warning_once( + f"Unsloth: You are pushing to hub, but you passed your HF username = {username}.\n"\ + f"We shall truncate {save_directory} to {new_save_directory}" + ) save_pretrained_settings["save_directory"] = new_save_directory tokenizer_save_settings ["save_directory"] = new_save_directory @@ -507,6 +515,10 @@ def unsloth_save_model( f"{round(max_ram/1024/1024/1024, 2)} out of "\ f"{round(psutil.virtual_memory().total/1024/1024/1024, 2)} RAM for saving.") + # Move temporary_location to /tmp in Kaggle + if IS_KAGGLE_ENVIRONMENT: + temporary_location = os.path.join(KAGGLE_TMP, temporary_location) + # Max directory for disk saving if not os.path.exists(temporary_location): os.makedirs(temporary_location) @@ -708,7 +720,7 @@ def unsloth_save_model( print("Done.") if push_to_hub and hasattr(model, "config"): - print(f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/')}") + print(f"Saved merged model to https://huggingface.co/{username}/{save_directory.lstrip('/').split('/')[-1]}") pass save_pretrained_settings["state_dict"] = None @@ -1108,14 +1120,17 @@ def save_to_gguf( # Check if quantization succeeded! if not os.path.isfile(final_location): if IS_KAGGLE_ENVIRONMENT: - raise RuntimeError( - f"Unsloth: Quantization failed for {final_location}\n"\ - "You are in a Kaggle environment, which might be the reason this is failing.\n"\ - "Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\ - "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ - "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ - "I suggest you to save the 16bit model first, then use manual llama.cpp conversion." - ) + if not Path(final_location).resolve().is_relative_to(Path('/tmp').resolve()): + raise RuntimeError( + f"Unsloth: Quantization failed for {final_location}\n"\ + "You are in a Kaggle environment, which might be the reason this is failing.\n"\ + "Kaggle only provides 20GB of disk space in the working directory.\n"\ + "Merging to 16bit for 7b models use 16GB of space.\n"\ + "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ + "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ + "You can try saving it to the `/tmp` directory for larger disk space.\n"\ + "I suggest you to save the 16bit model first, then use manual llama.cpp conversion." + ) else: raise RuntimeError( f"Unsloth: Quantization failed for {final_location}\n"\ @@ -1156,14 +1171,17 @@ def save_to_gguf( # Check if quantization succeeded! if not os.path.isfile(final_location): if IS_KAGGLE_ENVIRONMENT: - raise RuntimeError( - f"Unsloth: Quantization failed for {final_location}\n"\ - "You are in a Kaggle environment, which might be the reason this is failing.\n"\ - "Kaggle only provides 20GB of disk space. Merging to 16bit for 7b models use 16GB of space.\n"\ - "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ - "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ - "I suggest you to save the 16bit model first, then use manual llama.cpp conversion." - ) + if not Path(final_location).resolve().is_relative_to(Path('/tmp').resolve()): + raise RuntimeError( + f"Unsloth: Quantization failed for {final_location}\n"\ + "You are in a Kaggle environment, which might be the reason this is failing.\n"\ + "Kaggle only provides 20GB of disk space in the working directory.\n"\ + "Merging to 16bit for 7b models use 16GB of space.\n"\ + "This means using `model.{save_pretrained/push_to_hub}_merged` works, but\n"\ + "`model.{save_pretrained/push_to_hub}_gguf will use too much disk space.\n"\ + "You can try saving it to the `/tmp` directory for larger disk space.\n"\ + "I suggest you to save the 16bit model first, then use manual llama.cpp conversion." + ) else: raise RuntimeError( "Unsloth: Quantization failed! You might have to compile llama.cpp yourself, then run this again.\n"\ diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index c05485f902..c639dbf1a0 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -64,6 +64,7 @@ keynames = "\n" + "\n".join(os.environ.keys()) IS_COLAB_ENVIRONMENT = "\nCOLAB_" in keynames IS_KAGGLE_ENVIRONMENT = "\nKAGGLE_" in keynames +KAGGLE_TMP = "/tmp" del keynames @@ -470,8 +471,12 @@ def _load_correct_tokenizer( cache_dir = "huggingface_tokenizers_cache", fix_tokenizer = True, ): - if IS_COLAB_ENVIRONMENT or IS_KAGGLE_ENVIRONMENT: + if IS_COLAB_ENVIRONMENT: cache_dir = cache_dir + elif IS_KAGGLE_ENVIRONMENT: + # /tmp of Kaggle seems has a 80GB limit! + # Let's utilize them + cache_dir = os.path.join(KAGGLE_TMP, cache_dir) else: cache_dir = None pass From 7e1692ace19627e0cff5d8ece58b71a59d78c651 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:11:43 -0700 Subject: [PATCH 042/209] Bug fixes --- unsloth/__init__.py | 14 +++++++------- unsloth/kernels/cross_entropy_loss.py | 10 +++++++--- unsloth/models/_utils.py | 6 +++++- unsloth/models/llama.py | 1 + 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 458c2696bc..109e1c6d2f 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -27,13 +27,6 @@ # pass # pass -# Check for unsloth_zoo -try: - import unsloth_zoo -except: - raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`") -pass - # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so # enabling it will require much more work, so we have to prioritize. Please understand! # We do have a beta version, which you can contact us about! @@ -165,6 +158,13 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass pass +# Check for unsloth_zoo +try: + import unsloth_zoo +except: + raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`") +pass + from .models import * from .save import * from .chat_templates import * diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index debd037b64..a2337a14d9 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -17,7 +17,7 @@ import torch from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh from transformers.models.llama.modeling_llama import logger - +from packaging.version import Version @triton.heuristics({ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], @@ -352,7 +352,6 @@ def backward(ctx, dlosses): pass -# @torch._disable_dynamo def fast_cross_entropy_loss( logits, labels, @@ -380,6 +379,9 @@ def fast_cross_entropy_loss( n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items pass +if Version(torch.__version__) < Version("2.5.0"): + fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss) +pass from transformers.models.llama.modeling_llama import ( @@ -475,7 +477,6 @@ def unpatch_llama_for_causal_lm(): pass -# @torch._disable_dynamo def UnslothForCausalLMLoss( logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs ): @@ -490,6 +491,9 @@ def UnslothForCausalLMLoss( ) return loss pass +if Version(torch.__version__) < Version("2.5.0"): + UnslothForCausalLMLoss = torch._disable_dynamo(UnslothForCausalLMLoss) +pass def patch_transformers_losses(): diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 97b2ea7b7b..a39bc58db0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -760,7 +760,9 @@ def get_statistics(): # We log some basic stats about which environment is being used. # We simply download a README.md file from HF - all data is made public. # This is simply so we can check if some envs are broken or not. - # You can disable this by commenting the below out + # You can disable this by setting UNSLOTH_DISABLE_STATISTICS + import os + if "UNSLOTH_DISABLE_STATISTICS" in os.environ: return from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled disabled = False if not are_progress_bars_disabled(): @@ -1295,6 +1297,7 @@ def patch_gradient_accumulation_fix(Trainer): # Fixes gradient accumulation import inspect if hasattr(Trainer, "get_batch_samples"): + if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples": return if \ not inspect.getsource(Trainer.get_batch_samples).strip()\ .endswith("return batch_samples, num_items_in_batch"): @@ -1321,6 +1324,7 @@ def patch_gradient_accumulation_fix(Trainer): pass # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps + if Trainer.training_step.__name__ == "_unsloth_training_step": return if "num_items_in_batch" not in inspect.signature(Trainer.training_step).parameters: return function = inspect.getsource(Trainer.training_step) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 65e2d773e8..c0175bbfaf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1518,6 +1518,7 @@ def from_pretrained( pass # Return old flag os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = old_hf_transfer + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" model_patcher.pre_patch() get_statistics() # For debugging - we use a download counter to see if environments are not breaking From 6bef8f1c3cc2e0a97166c92b0d348f0753bc3ceb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:16:00 -0700 Subject: [PATCH 043/209] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc9c8256ad..8922cc7c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ huggingface = [ "unsloth_zoo", "packaging", "tyro", - "transformers>=4.44.2", + "transformers>=4.46.1", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", @@ -247,7 +247,7 @@ colab-new = [ "unsloth_zoo", "packaging", "tyro", - "transformers>=4.44.2", + "transformers>=4.46.1", "datasets>=2.16.0", "sentencepiece>=0.2.0", "tqdm", From 9ccbc0ed1fb32107a848a3e11f05b89f65107d18 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:35:01 -0700 Subject: [PATCH 044/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a39bc58db0..35b56f26a9 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -910,7 +910,7 @@ def backward(ctx, dY): pass -@torch._disable_dynamo +# @torch._disable_dynamo def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) pass From 95ecc5795dcc1de86ac1bca5a55dac6ae2c48f11 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:44:05 -0700 Subject: [PATCH 045/209] Update __init__.py --- unsloth/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 109e1c6d2f..23f54d213e 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -53,6 +53,13 @@ # Reduce VRAM usage by reducing fragmentation os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions) +keynames = "\n" + "\n".join(os.environ.keys()) +if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames: + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + print("Hello") +pass + try: import torch except ModuleNotFoundError: @@ -64,12 +71,6 @@ raise exception pass -# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions) -keynames = "\n" + "\n".join(os.environ.keys()) -if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames: - os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -pass - # We support Pytorch 2 # Fixes https://github.com/unslothai/unsloth/issues/38 torch_version = torch.__version__.split(".") From 5f5fef8075f5df30d2f7b72ac57e167c3354ebe9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:47:21 -0700 Subject: [PATCH 046/209] Update __init__.py --- unsloth/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 23f54d213e..91ec460094 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -57,7 +57,6 @@ keynames = "\n" + "\n".join(os.environ.keys()) if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - print("Hello") pass try: From 784dd13da70ed1c7c0f58b506c85f24dc024fa8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:51:10 -0700 Subject: [PATCH 047/209] Update _utils.py --- unsloth/models/_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 35b56f26a9..0539e255ea 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -371,16 +371,16 @@ def is_big_gpu(index): "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation ] -import torch._inductor.config as config -for _try_compile_argument in torch_compile_arguments: - try: exec(_try_compile_argument) - except: pass -pass -import torch._dynamo.config as config -for _try_dynamo_argument in torch_dynamo_arguments: - try: exec(_try_dynamo_argument) - except: pass -pass +# import torch._inductor.config as config +# for _try_compile_argument in torch_compile_arguments: +# try: exec(_try_compile_argument) +# except: pass +# pass +# import torch._dynamo.config as config +# for _try_dynamo_argument in torch_dynamo_arguments: +# try: exec(_try_dynamo_argument) +# except: pass +# pass torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, From 5b75e21a4b8bcda18e4b4d2d99beef041bf7dd3d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 13:54:11 -0700 Subject: [PATCH 048/209] Update _utils.py --- unsloth/models/_utils.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0539e255ea..a3cb3cda8d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -366,21 +366,21 @@ def is_big_gpu(index): # Torch dynamo arguments torch_dynamo_arguments = [ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - "config.suppress_errors = True", # Supress errors for now + # "config.suppress_errors = True", # Supress errors for now "config.do_not_emit_runtime_asserts = True", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation ] -# import torch._inductor.config as config -# for _try_compile_argument in torch_compile_arguments: -# try: exec(_try_compile_argument) -# except: pass -# pass -# import torch._dynamo.config as config -# for _try_dynamo_argument in torch_dynamo_arguments: -# try: exec(_try_dynamo_argument) -# except: pass -# pass +import torch._inductor.config as config +for _try_compile_argument in torch_compile_arguments: + try: exec(_try_compile_argument) + except: pass +pass +import torch._dynamo.config as config +for _try_dynamo_argument in torch_dynamo_arguments: + try: exec(_try_dynamo_argument) + except: pass +pass torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, From 74ab93c9da51a353b558e16ad521641f37b152b3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 14:00:03 -0700 Subject: [PATCH 049/209] Update _utils.py --- unsloth/models/_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a3cb3cda8d..a13723c9fb 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -366,7 +366,7 @@ def is_big_gpu(index): # Torch dynamo arguments torch_dynamo_arguments = [ "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - # "config.suppress_errors = True", # Supress errors for now + "config.suppress_errors = True", # Supress errors for now "config.do_not_emit_runtime_asserts = True", "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation @@ -376,11 +376,11 @@ def is_big_gpu(index): try: exec(_try_compile_argument) except: pass pass -import torch._dynamo.config as config -for _try_dynamo_argument in torch_dynamo_arguments: - try: exec(_try_dynamo_argument) - except: pass -pass +# import torch._dynamo.config as config +# for _try_dynamo_argument in torch_dynamo_arguments: +# try: exec(_try_dynamo_argument) +# except: pass +# pass torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, From 526505c11989d5e02f73d0578922748205092318 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 14:01:49 -0700 Subject: [PATCH 050/209] Update _utils.py --- unsloth/models/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a13723c9fb..0539e255ea 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -371,11 +371,11 @@ def is_big_gpu(index): "config.cache_size_limit = 1024", # Flex Attention "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation ] -import torch._inductor.config as config -for _try_compile_argument in torch_compile_arguments: - try: exec(_try_compile_argument) - except: pass -pass +# import torch._inductor.config as config +# for _try_compile_argument in torch_compile_arguments: +# try: exec(_try_compile_argument) +# except: pass +# pass # import torch._dynamo.config as config # for _try_dynamo_argument in torch_dynamo_arguments: # try: exec(_try_dynamo_argument) From 251ba777719a1a8dd403387bcd3f18e18b16024b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 14:57:14 -0700 Subject: [PATCH 051/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 87 ++++++++++++++------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index a2337a14d9..e23f818a11 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,23 +19,23 @@ from transformers.models.llama.modeling_llama import logger from packaging.version import Version -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _cross_entropy_forward( - logits_ptr, + logits_ptr : tl.pointer_type, logits_row_stride : tl.constexpr(tl.int64), - loss_ptr, - logsumexp_ptr, - labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING: tl.constexpr, - LOGIT_SCALE : tl.constexpr, + loss_ptr : tl.pointer_type(tl.float32), + logsumexp_ptr : tl.pointer_type(tl.float32), + labels_ptr : tl.const_pointer_type(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr(tl.int1), + SOFTCAP : tl.constexpr(tl.float32), + DO_LOGIT_SCALING : tl.constexpr(tl.int1), + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -92,24 +92,24 @@ def _cross_entropy_forward( pass -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _chunked_cross_entropy_forward( - logits_ptr, + logits_ptr : tl.const_pointer_type, logits_row_stride : tl.constexpr(tl.int64), - loss_ptr, - logsumexp_ptr, - labels_ptr, - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING: tl.constexpr, - LOGIT_SCALE : tl.constexpr, + loss_ptr : tl.pointer_type(tl.float32), + logsumexp_ptr : tl.pointer_type(tl.float32), + labels_ptr : tl.const_pointer_type(tl.int32), + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr(tl.int1), + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr(tl.int1), + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -175,23 +175,24 @@ def _chunked_cross_entropy_forward( pass -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _cross_entropy_backward( - logits_ptr, + logits_ptr : tl.pointer_type, logits_row_stride : tl.constexpr(tl.int64), - dloss_ptr, dloss_row_stride, - logsumexp_ptr, - labels_ptr, - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING: tl.constexpr, - LOGIT_SCALE : tl.constexpr, + dloss_ptr : tl.const_pointer_type(tl.float32), + dloss_row_stride : tl.constexpr, + logsumexp_ptr : tl.const_pointer_type(tl.float32), + labels_ptr : tl.const_pointer_type(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr(tl.int1), + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr(tl.int1), + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 530c4958e8d7394862973ca6440a317a416ebd0f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 15:33:37 -0700 Subject: [PATCH 052/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 70 +++++++++++++-------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index e23f818a11..fd34b9a4df 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,17 +25,17 @@ # }) @triton.jit def _cross_entropy_forward( - logits_ptr : tl.pointer_type, - logits_row_stride : tl.constexpr(tl.int64), - loss_ptr : tl.pointer_type(tl.float32), - logsumexp_ptr : tl.pointer_type(tl.float32), - labels_ptr : tl.const_pointer_type(tl.int32), - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + logits_ptr , + logits_row_stride , + loss_ptr , + logsumexp_ptr , + labels_ptr , + VOCAB_SIZE , + BLOCK_SIZE , + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -98,18 +98,18 @@ def _cross_entropy_forward( # }) @triton.jit def _chunked_cross_entropy_forward( - logits_ptr : tl.const_pointer_type, - logits_row_stride : tl.constexpr(tl.int64), - loss_ptr : tl.pointer_type(tl.float32), - logsumexp_ptr : tl.pointer_type(tl.float32), - labels_ptr : tl.const_pointer_type(tl.int32), - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr, + logits_ptr , + logits_row_stride , + loss_ptr , + logsumexp_ptr , + labels_ptr , + VOCAB_SIZE , + N_CHUNKS , + BLOCK_SIZE , + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -181,18 +181,18 @@ def _chunked_cross_entropy_forward( # }) @triton.jit def _cross_entropy_backward( - logits_ptr : tl.pointer_type, - logits_row_stride : tl.constexpr(tl.int64), - dloss_ptr : tl.const_pointer_type(tl.float32), - dloss_row_stride : tl.constexpr, - logsumexp_ptr : tl.const_pointer_type(tl.float32), - labels_ptr : tl.const_pointer_type(tl.int32), - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr, + logits_ptr , + logits_row_stride , + dloss_ptr , + dloss_row_stride , + logsumexp_ptr , + labels_ptr , + VOCAB_SIZE , + BLOCK_SIZE , + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 07394c34368efbbc338b3c664d34a8540c308ac4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 15:46:09 -0700 Subject: [PATCH 053/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index fd34b9a4df..20b57b30b7 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -31,7 +31,7 @@ def _cross_entropy_forward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -105,7 +105,7 @@ def _chunked_cross_entropy_forward( labels_ptr , VOCAB_SIZE , N_CHUNKS , - BLOCK_SIZE , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -188,7 +188,7 @@ def _cross_entropy_backward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , From 6d7004b5ca5f1dd9e823d6c3ee26855ac2e9ab6f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 16:37:37 -0700 Subject: [PATCH 054/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 20b57b30b7..c5f054a43f 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -223,7 +223,7 @@ def _cross_entropy_backward( else: dloss = 0.0 - x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Do logit scaling for Cohere if DO_LOGIT_SCALING: @@ -239,7 +239,7 @@ def _cross_entropy_backward( pass logsumexp = tl.load(logsumexp_ptr + row_idx) - y = tl.exp(x.to(tl.float32) - logsumexp) + y = tl.exp(x - logsumexp) y = tl.where( col_offsets == label_idx, y - 1.0, # exp(x - logsumexp) - 1 From d86b20a50fded8c4a6b6c73e079a0600cf0d1d5a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 16:46:33 -0700 Subject: [PATCH 055/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index c5f054a43f..474c291ed8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -232,6 +232,7 @@ def _cross_entropy_backward( pass # Do logit softcapping for Gemma 2: t * tanh(1/t * x) + partial = 0.0 if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) partial = triton_tanh(x / SOFTCAP) From 9920950b7fb7d8116d007547ad6aef8027d0f950 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 16:50:56 -0700 Subject: [PATCH 056/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 474c291ed8..9beb2da25c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -353,7 +353,7 @@ def backward(ctx, dlosses): pass pass - +@torch._disable_dynamo def fast_cross_entropy_loss( logits, labels, From 9f926ced28d377e6c8616f4021570792995c6321 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 16:53:04 -0700 Subject: [PATCH 057/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9beb2da25c..28e40487c1 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -67,14 +67,14 @@ def _cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) - logits = logits.to(tl.float32) + # logits = logits c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) @@ -146,14 +146,14 @@ def _chunked_cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) - logits = logits.to(tl.float32) + # logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) From 30cdf652d331d794f2eaf3e118cfb6372bb13591 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 16:55:24 -0700 Subject: [PATCH 058/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 28e40487c1..266ece4ca2 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -232,7 +232,7 @@ def _cross_entropy_backward( pass # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - partial = 0.0 + partial = x if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) partial = triton_tanh(x / SOFTCAP) From 54b901bc50b7798a78ed862813d4d986dd4c1e24 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 17:19:30 -0700 Subject: [PATCH 059/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 266ece4ca2..66d5046deb 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -353,7 +353,7 @@ def backward(ctx, dlosses): pass pass -@torch._disable_dynamo +# @torch._disable_dynamo def fast_cross_entropy_loss( logits, labels, From 6db9d286d809cf8f29973d22e205d4bc0841a65a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 17:26:34 -0700 Subject: [PATCH 060/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 66d5046deb..9f16e8e605 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -273,8 +273,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = logits.device) - DO_SOFTCAPPING = (logit_softcapping != 0) - DO_LOGIT_SCALING = (logit_scaling != 0) + DO_SOFTCAPPING = logit_softcapping != 0 + DO_LOGIT_SCALING = logit_scaling != 0 if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral From 8aefcd0b5000efcd2406e8b56bea6770d3ed9f82 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 30 Oct 2024 17:35:20 -0700 Subject: [PATCH 061/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9f16e8e605..8cbdbf2a2c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -273,8 +273,8 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = logits.device) - DO_SOFTCAPPING = logit_softcapping != 0 - DO_LOGIT_SCALING = logit_scaling != 0 + DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) + DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral From 7bf626b36b63bda24defd6d0b20c59bb6312fdbb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 01:23:41 -0700 Subject: [PATCH 062/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 8cbdbf2a2c..a13475e227 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -92,10 +92,10 @@ def _cross_entropy_forward( pass -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _chunked_cross_entropy_forward( logits_ptr , @@ -106,9 +106,9 @@ def _chunked_cross_entropy_forward( VOCAB_SIZE , N_CHUNKS , BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , + DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP , - DO_LOGIT_SCALING , + DO_LOGIT_SCALING : tl.constexpr(tl.int1), LOGIT_SCALE , ): """ @@ -189,9 +189,9 @@ def _cross_entropy_backward( labels_ptr , VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , + DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP , - DO_LOGIT_SCALING , + DO_LOGIT_SCALING : tl.constexpr(tl.int1), LOGIT_SCALE , ): """ @@ -347,7 +347,7 @@ def backward(ctx, dlosses): SOFTCAP = ctx.logit_softcapping, DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, LOGIT_SCALE = ctx.logit_scaling, - num_warps = 8, + num_warps = 8, ) return logits, None, None, None, pass From d4557513032b582d9a4e4d9fc3efd9484f9705e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 01:56:03 -0700 Subject: [PATCH 063/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 64 +++++++++++++-------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index a13475e227..f256918746 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,23 +19,23 @@ from transformers.models.llama.modeling_llama import logger from packaging.version import Version -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr(tl.int1), + SOFTCAP : tl.constexpr(tl.float32), + DO_LOGIT_SCALING : tl.constexpr(tl.int1), + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -67,14 +67,14 @@ def _cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype) - # logits = logits + logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) @@ -99,7 +99,7 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , @@ -107,9 +107,9 @@ def _chunked_cross_entropy_forward( N_CHUNKS , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ 256K vocab divided in 4 chunks @@ -146,14 +146,14 @@ def _chunked_cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype) - # logits = logits.to(tl.float32) + logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) @@ -175,14 +175,14 @@ def _chunked_cross_entropy_forward( pass -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], + "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +}) @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , dloss_row_stride , logsumexp_ptr , @@ -190,9 +190,9 @@ def _cross_entropy_backward( VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) @@ -223,7 +223,7 @@ def _cross_entropy_backward( else: dloss = 0.0 - x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) + x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) # Do logit scaling for Cohere if DO_LOGIT_SCALING: @@ -235,12 +235,12 @@ def _cross_entropy_backward( partial = x if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) - partial = triton_tanh(x / SOFTCAP) + partial = triton_tanh(x.to(tl.float32) / SOFTCAP).to(x.dtype) x = SOFTCAP * partial pass logsumexp = tl.load(logsumexp_ptr + row_idx) - y = tl.exp(x - logsumexp) + y = tl.exp(x.to(tl.float32) - logsumexp) y = tl.where( col_offsets == label_idx, y - 1.0, # exp(x - logsumexp) - 1 @@ -271,7 +271,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): div, mod = divmod(vocab_size, MAX_FUSED_SIZE) n_chunks = div + (mod != 0) - losses = torch.empty(n_rows, dtype = torch.float32, device = logits.device) + losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) @@ -279,7 +279,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) - logsumexp = torch.empty(n_rows, dtype = torch.float32, device = logits.device) + logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") _cross_entropy_forward[(n_rows,)]( logits, logits.stride(0), @@ -296,7 +296,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): ) else: # For large vocabs > 65336 like Gemma 256K - logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = logits.device) + logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0") _chunked_cross_entropy_forward[(n_rows, n_chunks,)]( logits, logits.stride(0), From 055eeb8c47bf63a426fe64cecfd46247d27f6f1a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 02:00:04 -0700 Subject: [PATCH 064/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f256918746..f8b9b4245e 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -30,8 +30,8 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), @@ -103,9 +103,9 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + N_CHUNKS : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), @@ -187,8 +187,8 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), From 8090b7c01aaceecac4263f9af2737fdb76ebd458 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 12:36:21 -0700 Subject: [PATCH 065/209] Tied weights --- unsloth/models/llama.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c0175bbfaf..0e9b70a8b2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1838,7 +1838,8 @@ def post_patch(model, tokenizer): except: old_output_embedding = torch.zeros(0) # Check for tied weights as well - is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() + # is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() + is_tied = model.config.tie_word_embeddings # Check pad token's id -> we need to expand the embedding if len(tokenizer) > old_input_embedding.shape[0]: @@ -1887,6 +1888,9 @@ def post_patch(model, tokenizer): else: correct_dtype = old_input_embedding.dtype pass + + # Finally tie them if needed! + if is_tied: model.tie_weights() # Also patch all dtypes - BnB seems to not allocate the correct type? # BnB default dtype seems to be float16! From 7559efbbfd24037ba4501e1da7b6f9f18581b102 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 12:38:20 -0700 Subject: [PATCH 066/209] Revert "Tied weights" This reverts commit 8090b7c01aaceecac4263f9af2737fdb76ebd458. --- unsloth/models/llama.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 0e9b70a8b2..c0175bbfaf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1838,8 +1838,7 @@ def post_patch(model, tokenizer): except: old_output_embedding = torch.zeros(0) # Check for tied weights as well - # is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() - is_tied = model.config.tie_word_embeddings + is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() # Check pad token's id -> we need to expand the embedding if len(tokenizer) > old_input_embedding.shape[0]: @@ -1888,9 +1887,6 @@ def post_patch(model, tokenizer): else: correct_dtype = old_input_embedding.dtype pass - - # Finally tie them if needed! - if is_tied: model.tie_weights() # Also patch all dtypes - BnB seems to not allocate the correct type? # BnB default dtype seems to be float16! From ad63a32a0332cc25c4ee72cb8aab5e2fb4f10c8a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 31 Oct 2024 16:25:26 -0700 Subject: [PATCH 067/209] Tied weights --- unsloth/models/_utils.py | 30 ++++++++++++++++++++++++++++-- unsloth/models/llama.py | 7 ++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0539e255ea..e1ed649560 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -479,6 +479,32 @@ def patch_tokenizer(model, tokenizer): if model is not None: model.config.update({"unsloth_version" : __version__}) + # First remove pad and unk tokens if they are known to be BOS / EOS + possible_bad_tokens = ( + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|begin_of_text|>", + "<|end_of_text|>", + "", + "", + ) + input_ids = tokenizer(list(possible_bad_tokens), add_special_tokens = False).input_ids + possible_bad_tokens = frozenset(token for token, input_id in zip(possible_bad_tokens, input_ids) if len(input_id) == 1) + + if hasattr(tokenizer, "pad_token") and tokenizer.pad_token in possible_bad_tokens: + print(f"Unsloth: Pad token was {tokenizer.pad_token} which is not a good idea. We shall fix this.") + tokenizer.pad_token = None + pass + + has_bad_unk_token = False + if hasattr(tokenizer, "unk_token") and tokenizer.unk_token in possible_bad_tokens: + print(f"Unsloth: Unk token was {tokenizer.unk_token} which is not a good idea. We shall fix this.") + tokenizer.unk_token = None + has_bad_unk_token = True + pass + + # Now check pad token again bad_pad_token = False if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! @@ -492,13 +518,13 @@ def patch_tokenizer(model, tokenizer): # Check if unknown token is broken fixed_unk_token = False - if hasattr(tokenizer, "unk_token") and tokenizer.unk_token is not None: + if (hasattr(tokenizer, "unk_token") and tokenizer.unk_token is not None) or has_bad_unk_token: eos_token = getattr(tokenizer, "eos_token", None) bos_token = getattr(tokenizer, "bos_token", None) old_unk_token = tokenizer.unk_token - if old_unk_token == eos_token or old_unk_token == bos_token: + if (old_unk_token == eos_token) or (old_unk_token == bos_token) or has_bad_unk_token: has_broken_unk = True # Use the unicode replacement characters possible_replacements = [ diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c0175bbfaf..8c4b7fd253 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1838,7 +1838,8 @@ def post_patch(model, tokenizer): except: old_output_embedding = torch.zeros(0) # Check for tied weights as well - is_tied = old_input_embedding.data_ptr() == old_output_embedding.data_ptr() + is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ + or (model.config.tie_word_embeddings) # Check pad token's id -> we need to expand the embedding if len(tokenizer) > old_input_embedding.shape[0]: @@ -1887,6 +1888,10 @@ def post_patch(model, tokenizer): else: correct_dtype = old_input_embedding.dtype pass + + # Must tie lm_head and embed_tokens if they are tied! + # Otherwise error will occur on saving models ie use save_model + if is_tied: model.tie_weights() # Also patch all dtypes - BnB seems to not allocate the correct type? # BnB default dtype seems to be float16! From 35aa99261545731aa6c7728f53348579fedb1db0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 2 Nov 2024 19:08:57 -0700 Subject: [PATCH 068/209] Utils --- unsloth/kernels/cross_entropy_loss.py | 1 + unsloth/models/_utils.py | 369 +------------------------- 2 files changed, 11 insertions(+), 359 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f8b9b4245e..f0dbb66f49 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,6 +19,7 @@ from transformers.models.llama.modeling_llama import logger from packaging.version import Version + @triton.heuristics({ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e1ed649560..4829e4d88c 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -56,6 +56,16 @@ import warnings, subprocess, re, inspect, psutil, os, math from packaging.version import Version +from unsloth_zoo.tokenizer_utils import ( + patch_tokenizer, +) +from unsloth_zoo.gradient_checkpointing import ( + Unsloth_Offloaded_Gradient_Checkpointer, + unsloth_offloaded_gradient_checkpoint, + patch_gradient_checkpointing, + unpatch_gradient_checkpointing, +) + # ============================================= # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") @@ -131,7 +141,6 @@ def patch_mistral_nemo_config(config): # ============================================= # torch.cuda.amp.custom_fwd is deprecated >= 2.4 -import torch torch_version = torch.__version__ if Version(torch_version) < Version("2.4.0"): torch_amp_custom_fwd = torch.cuda.amp.custom_fwd @@ -457,228 +466,6 @@ def make_inputs_require_grad(module, input, output): return model pass - -def patch_tokenizer(model, tokenizer): - """ - Phi3's pad_token isn't set. We set it to <|placeholder... - Llama-3 is <|reserved... - Llama-2 is - Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! - Fixes https://github.com/unslothai/unsloth/issues/5 - """ - possible_reserved_tokens = ( - "<|finetune_right_pad_id|>", # Llama-3.1 - "", # Mistral Nemo - "<|reserved", # Llama-3 - "<|placeholder", # Phi-3 - "[control", # Mistral type models - ) - joiner = "\1\0=+=\0\1" - number_repetitions = 3 - 1 # Number of reserved tokens needed - - if model is not None: - model.config.update({"unsloth_version" : __version__}) - - # First remove pad and unk tokens if they are known to be BOS / EOS - possible_bad_tokens = ( - "<|endoftext|>", - "<|im_start|>", - "<|im_end|>", - "<|begin_of_text|>", - "<|end_of_text|>", - "", - "", - ) - input_ids = tokenizer(list(possible_bad_tokens), add_special_tokens = False).input_ids - possible_bad_tokens = frozenset(token for token, input_id in zip(possible_bad_tokens, input_ids) if len(input_id) == 1) - - if hasattr(tokenizer, "pad_token") and tokenizer.pad_token in possible_bad_tokens: - print(f"Unsloth: Pad token was {tokenizer.pad_token} which is not a good idea. We shall fix this.") - tokenizer.pad_token = None - pass - - has_bad_unk_token = False - if hasattr(tokenizer, "unk_token") and tokenizer.unk_token in possible_bad_tokens: - print(f"Unsloth: Unk token was {tokenizer.unk_token} which is not a good idea. We shall fix this.") - tokenizer.unk_token = None - has_bad_unk_token = True - pass - - # Now check pad token again - bad_pad_token = False - if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None: - # Check if pad_token is not the same as eos_token otherwise the loss will ignore it!! - bad_pad_token = tokenizer.eos_token == tokenizer.pad_token - elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None: - bad_pad_token = True - else: - bad_pad_token = False - pass - - # Check if unknown token is broken - fixed_unk_token = False - - if (hasattr(tokenizer, "unk_token") and tokenizer.unk_token is not None) or has_bad_unk_token: - - eos_token = getattr(tokenizer, "eos_token", None) - bos_token = getattr(tokenizer, "bos_token", None) - - old_unk_token = tokenizer.unk_token - if (old_unk_token == eos_token) or (old_unk_token == bos_token) or has_bad_unk_token: - has_broken_unk = True - # Use the unicode replacement characters - possible_replacements = [ - "\uFFFD", # Original replacement char - "\uFFFC", # Another option - "\u2753", # Red Question mark emoji - "\u2754", # White Question mark emoji - "\u00BF", # Inverted question mark - ] - for replacement_char in possible_replacements: - char = tokenizer(replacement_char, add_special_tokens = False).input_ids - if len(char) == 1: - # Get actual token representation - try: char = tokenizer.convert_ids_to_tokens(char[0]) - except: continue - tokenizer.unk_token = char - fixed_unk_token = True - break - pass - pass - - if not fixed_unk_token: # Still broken! - raise RuntimeError( - f"Unsloth: Tried fixing the unk_token = {old_unk_token}, but couldn't!" - ) - pass - - logger.warning_once( - f"Unsloth: unk_token = {old_unk_token} is the same as the EOS or BOS tokens. "\ - f"We fixed it by changing it to {tokenizer.unk_token}." - ) - pass - pass - - if bad_pad_token: - # Find a better pad token - added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()] - all_added_tokens = joiner.join(added_tokens[::-1]) - all_added_tokens += joiner - - final_pad_token = None - final_good_match = False - - for possible_reserved_token in possible_reserved_tokens: - possible_reserved_token = re.escape(possible_reserved_token) - found = re.finditer(f"{possible_reserved_token}", all_added_tokens) - first_match = None - good_match = False - for j, x in enumerate(found): - if j == 0: first_match = x - if j >= number_repetitions: - good_match = True - break - pass - pass - - if first_match is None: continue - - # If it ends with |> or > etc, then set it as a good pad token! - start = first_match.span(0)[0] - possible_pad_token = first_match.group(0) - end = all_added_tokens.find(joiner, start) - first_match = all_added_tokens[start:end] - - if first_match is not None: - good_match = possible_pad_token.endswith((">", "|>", "]", ")")) - pass - possible_pad_token = first_match - - # Replace current pad token if another exact match is found - if not final_good_match and good_match: - final_good_match = True - final_pad_token = possible_pad_token - break - else: - final_good_match = False - final_pad_token = possible_pad_token - pass - pass - possible_pad_token = final_pad_token - - # Try unk_token if it wasn't fixed - if possible_pad_token is None and not fixed_unk_token and hasattr(tokenizer, "unk_token"): - possible_pad_token = tokenizer.unk_token - pass - - # Check pad token's id must be less than vocab size - if possible_pad_token is not None: - check_pad_token = tokenizer(possible_pad_token, add_special_tokens = False).input_ids - if len(check_pad_token) != 1: - possible_pad_token = None - if model is not None and check_pad_token[0] >= model.config.vocab_size: - possible_pad_token = None - pass - - if possible_pad_token is None: - # Failure to find a good replacement!! We shall manually add one! - new_pad_token = "<|PAD_TOKEN|>" - while new_pad_token in tokenizer.get_vocab(): - new_pad_token = f"<{new_pad_token}>" - pass - possible_pad_token = new_pad_token - pass - - name = model.config._name_or_path if model is not None else "Model" - logger.warning_once( - f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}." - ) - - # Edit pad_token - tokenizer.add_special_tokens({"pad_token" : possible_pad_token}) - tokenizer.pad_token = possible_pad_token - if model is not None: - - # Edit all config with new pad token - current_model = model - while hasattr(current_model, "model") and hasattr(current_model, "config"): - current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) - current_model = current_model.model - if hasattr(current_model, "model") and hasattr(current_model, "config"): - current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) - pass - - # Generation edit pad token - if getattr(model, "generation_config") is not None: - model.generation_config.update(pad_token_id = tokenizer.pad_token_id) - else: - if model is not None: - - if model.config.pad_token_id is None: - - # Edit all config with new pad token - current_model = model - while hasattr(current_model, "model") and hasattr(current_model, "config"): - current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) - current_model = model - if hasattr(current_model, "model") and hasattr(current_model, "config"): - current_model.config.update({"pad_token_id" : tokenizer.pad_token_id}) - pass - - # Generation edit pad token - if getattr(model, "generation_config") is not None: - model.generation_config.update(pad_token_id = tokenizer.pad_token_id) - pass - pass - - if model is not None: - if getattr(model, "generation_config") is not None: - model.generation_config.update(max_length = model.config.max_position_embeddings) - - return model, tokenizer -pass - - # ============================================= # Weirdly LoraLayer.update_layer downcasts PEFT layers to float16?? # For mixed precision, we need it to be in float32 not float16. @@ -820,142 +607,6 @@ def get_statistics(): pass -def _calculate_n_gradient_checkpoints( - n_layers : int, - method : Optional[Union[str, int]] = "sqrt", -) -> List[int]: - assert(type(n_layers) is int and n_layers > 0) - - if method is None: method = "sqrt" - - if method == "sqrt": - n_checkpoints = int(n_layers**0.5) - elif type(method) is int and method > 0: - n_checkpoints = int(np.ceil(n_layers / method)) - else: - raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.") - - size = n_layers // n_checkpoints - sizes = np.full(n_checkpoints, size, dtype = int) - leftovers = n_layers % n_checkpoints - # We append leftovers from the right - for k in range(leftovers): - sizes[n_checkpoints-1-k] += 1 - boundaries = np.hstack((0, np.cumsum(sizes))) - boundaries = boundaries.tolist() - return boundaries -pass - - -def calculate_n_gradient_checkpoints( - n_layers : int, - layers_per_checkpoint : Optional[Union[str, int]] = "sqrt", -) -> List[int]: - assert(type(n_layers) is int and n_layers > 0) - - if layers_per_checkpoint is None or layers_per_checkpoint == 1: - return None - - boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint) - - assert(boundaries[0] == 0 and boundaries[-1] == n_layers) - assert(min(boundaries) == 0 and max(boundaries) == n_layers) - assert(np.diff(boundaries).min() >= 0) - return boundaries -pass - - -def prepare_n_gradient_checkpoints( - model : Any, - layers_per_checkpoint : Optional[Union[str, int]] = "sqrt", - use_reentrant : Optional[bool] = True, -) -> None: - """ - Calculates where to place the gradient checkpoints given n_layers. - - Args: - model: Any LlamaModel with layers. - layers_per_checkpoint (`Union[str, int]`, *optional*): - Can either be `sqrt` or an integer for how many layers per checkpoint you want. - The more, the less memory usage, but can be slower. Default is `sqrt`. - Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc. - use_reentrant (`bool`, *optional*): - https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354 - Optimal gradient checkpointing algorithm `use_reentrant=False` which will - be the default in future Pytorch versions doesn't seem to work?? - """ - _model = None - if hasattr(model, "layers"): - _model = model - elif hasattr(model, "model"): - if hasattr(model.model, "layers"): - _model = model.model - if _model is None: - raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?") - pass - - if use_reentrant is False: - use_reentrant = True - pass - - n_layers = len(_model.layers) - boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint) - _model._gradient_checkpointing_boundaries = boundaries - _model._gradient_checkpointing_use_reentrant = use_reentrant -pass - - -class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): - """ - Saves VRAM by smartly offloading to RAM. - Tiny hit to performance, since we mask the movement via non blocking calls. - """ - @staticmethod - @torch_amp_custom_fwd - def forward(ctx, forward_function, hidden_states, *args): - saved_hidden_states = hidden_states.to("cpu", non_blocking = True) - with torch.no_grad(): - output = forward_function(hidden_states, *args) - ctx.save_for_backward(saved_hidden_states) - ctx.forward_function = forward_function - ctx.args = args - return output - pass - - @staticmethod - @torch_amp_custom_bwd - def backward(ctx, dY): - (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach() - hidden_states.requires_grad_(True) - with torch.enable_grad(): - (output,) = ctx.forward_function(hidden_states, *ctx.args) - torch.autograd.backward(output, dY) - return (None, hidden_states.grad,) + (None,)*len(ctx.args) - pass -pass - - -# @torch._disable_dynamo -def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs): - return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args) -pass - -import torch.utils -def patch_gradient_checkpointing(): - if torch.utils.checkpoint.checkpoint.__name__ == "unsloth_offloaded_gradient_checkpoint": return - torch.utils.checkpoint._old_checkpoint = torch.utils.checkpoint.checkpoint - torch.utils.checkpoint.checkpoint = unsloth_offloaded_gradient_checkpoint -pass - -def unpatch_gradient_checkpointing(): - if hasattr(torch.utils.checkpoint, "_old_checkpoint"): - torch.utils.checkpoint.checkpoint = torch.utils.checkpoint._old_checkpoint - del torch.utils.checkpoint._old_checkpoint - pass -pass - - # ============================================= # Regional torch 2.5 Recompilation - weirdly very slow?? def patch_regional_compilation(): From 0172ee34efe92fab26e04ff3cd39120af7f3e852 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 2 Nov 2024 19:20:48 -0700 Subject: [PATCH 069/209] CE Loss patching --- unsloth/kernels/__init__.py | 4 +- unsloth/kernels/cross_entropy_loss.py | 158 ++------------------------ 2 files changed, 12 insertions(+), 150 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 6357ddaf87..9d5b2da4f9 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -14,9 +14,7 @@ from .cross_entropy_loss import ( fast_cross_entropy_loss, - patch_llama_for_causal_lm, - unpatch_llama_for_causal_lm, - patch_transformers_losses, + patch_losses, patch_loss_function, ) from .rms_layernorm import ( diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f0dbb66f49..5236e51985 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -19,6 +19,12 @@ from transformers.models.llama.modeling_llama import logger from packaging.version import Version +from unsloth_zoo.loss_utils import ( + causal_loss_function, + transformers_losses_patcher, + patch_loss_function, +) + @triton.heuristics({ "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], @@ -354,7 +360,7 @@ def backward(ctx, dlosses): pass pass -# @torch._disable_dynamo + def fast_cross_entropy_loss( logits, labels, @@ -382,152 +388,10 @@ def fast_cross_entropy_loss( n_items = torch.count_nonzero(labels != -100) return loss.sum() / n_items pass -if Version(torch.__version__) < Version("2.5.0"): +if (Version(torch.__version__) < Version("2.4.0")) and \ + not hasattr(fast_cross_entropy_loss, "__wrapped__"): fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss) pass - -from transformers.models.llama.modeling_llama import ( - LlamaForCausalLM, - CausalLMOutputWithPast, - Optional, - Union, - Cache, - List, - Tuple, -) - -# Transformers 4.47 need Unpack, KwargsForCausalLM -try: - from transformers.models.llama.modeling_llama import Unpack, KwargsForCausalLM -except: - pass -pass - -import inspect, re -function = inspect.getsource(LlamaForCausalLM.forward) -function = function.split("\n") -i = re.match(r"[ ]{1,}", function[0]).span(0)[1] -function = [x[i:] for x in function] -function = "\n".join(function) -function = function[function.find("def forward"):] -replacement = """ loss = None - logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) - logit_scaling = getattr(self.config, "logit_scale", 0) - if labels is not None: - shift_logits = logits - if not hasattr(self, "extra_ignored_labels"): - # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") - pass - - shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) - loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - logit_softcapping = logit_softcapping, - logit_scaling = logit_scaling, - n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), - ) - else: - if logit_scaling != 0: - if logits.requires_grad: - logits = logit_scaling * logits - else: - logits *= logit_scaling - pass - pass - if logit_softcapping != 0: - if logits.requires_grad: - logits = (1.0 / logit_softcapping) * logits - logits = torch.tanh(logits) - logits = logit_softcapping * logits - else: - logits *= (1.0 / logit_softcapping) - torch.tanh(logits, out = logits) - logits *= logit_softcapping - pass - pass - pass -""" -function = \ - function[:function.find(" loss = None")] + \ - replacement + \ - function[ function.find(" if not return_dict"):] -function = function.replace("logits = logits.float()", "\n") -# Missed spaces -function = function.split("\n") -# Not the first one though! -function = [function[0]] + [" "*4 + x for x in function[1:]] -function = "\n".join(function) -function = f"class Unsloth_LlamaForCausalLM(LlamaForCausalLM):\n"\ -f" {function}\n" -exec(function, globals()) -del function, replacement, inspect, re - - -def patch_llama_for_causal_lm(): - import transformers.models.llama.modeling_llama - transformers.models.llama.modeling_llama.LlamaForCausalLM = Unsloth_LlamaForCausalLM - return -pass - - -def unpatch_llama_for_causal_lm(): - import transformers.models.llama.modeling_llama - transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM - return -pass - - -def UnslothForCausalLMLoss( - logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs -): - shift_logits = logits - shift_labels = torch.empty_like(labels) - shift_labels[..., :-1] = labels[..., 1:] - shift_labels[..., -1] = -100 - loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = shift_labels, - n_items = num_items_in_batch, - ) - return loss -pass -if Version(torch.__version__) < Version("2.5.0"): - UnslothForCausalLMLoss = torch._disable_dynamo(UnslothForCausalLMLoss) -pass - - -def patch_transformers_losses(): - import re - try: - import transformers.loss.loss_utils - except: - logger.warning_once("Unsloth: Cannot patch loss functions - update transformers for faster modules!") - return - pass - - import transformers.modeling_utils - LOSS_MAPPING = transformers.loss.loss_utils.LOSS_MAPPING - LOSS_MAPPING["ForCausalLM"] = UnslothForCausalLMLoss - - # Remove @property and @lru_cache - if hasattr(transformers.modeling_utils.PreTrainedModel.loss_function, "fget"): - transformers.modeling_utils.PreTrainedModel.loss_function = \ - transformers.modeling_utils.PreTrainedModel.loss_function.fget.__wrapped__ - pass -pass - - -def patch_loss_function(model): - try: - # model.loss_function starts as a dict to a loss fx - # We invoke it to save it - model.loss_function = model.loss_function() - except: - # Failed means we already invoked it, and we need args to the loss fx - pass - pass - return model -pass +# Patch CE Losses in transformers +patch_losses = transformers_losses_patcher(causal_loss_function(fast_cross_entropy_loss)) From c228682c40033520cce9e7167237334f529b95d4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 2 Nov 2024 23:14:41 -0700 Subject: [PATCH 070/209] Update __init__.py --- unsloth/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 91ec460094..5102d8f466 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -53,12 +53,14 @@ # Reduce VRAM usage by reducing fragmentation os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions) -keynames = "\n" + "\n".join(os.environ.keys()) -if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames: +# Hugging Face Hub faster downloads +if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" pass +# Log Unsloth is being used +os.environ["UNSLOTH_IS_PRESENT"] = "1" + try: import torch except ModuleNotFoundError: From 9aa221a421c9fba219bc5989537ccec0f993e284 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 00:22:38 -0700 Subject: [PATCH 071/209] Update __init__.py --- unsloth/kernels/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 9d5b2da4f9..df7589719d 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -54,8 +54,12 @@ create_flex_attention_sliding_window_mask, ) -try: - print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") -except: - print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") +import os +if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ: + try: + print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.") + except: + print("Unsloth: Will patch your computer to enable 2x faster free finetuning.") + pass pass +del os From 751413ed00e70c5ad4ffbec96b61f5eaa20bba54 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 01:18:48 -0700 Subject: [PATCH 072/209] Patching --- unsloth/kernels/layernorm.py | 22 +----------- unsloth/models/_utils.py | 67 ++++++++++-------------------------- 2 files changed, 20 insertions(+), 69 deletions(-) diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py index 48ade6d5ec..48d1a65ec7 100644 --- a/unsloth/kernels/layernorm.py +++ b/unsloth/kernels/layernorm.py @@ -17,6 +17,7 @@ import triton.language as tl import torch from .utils import calculate_settings +from unsloth_zoo import patch_layernorm @triton.jit @@ -162,27 +163,6 @@ def fast_layernorm(layernorm, X): pass -from torch.nn import LayerNorm -class Unsloth_LayerNorm(LayerNorm): - def forward(self, X): - return fast_layernorm(self, X) - pass -pass - - -def patch_layernorm(): - import torch.nn - torch.nn.LayerNorm = Unsloth_LayerNorm - return -pass - - -def unpatch_layernorm(): - import torch.nn - torch.nn.LayerNorm = LayerNorm - return -pass - def test_layernorm( dim = 1024, eps = 1e-5, dtype = torch.float16, diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4829e4d88c..64cb222e61 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -59,6 +59,11 @@ from unsloth_zoo.tokenizer_utils import ( patch_tokenizer, ) +from unsloth_zoo.patching_utils import ( + patch_compiling_bitsandbytes, + patch_layernorm, + patch_torch_compile, +) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, unsloth_offloaded_gradient_checkpoint, @@ -356,47 +361,27 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu +patch_torch_compile() - -# Torch compile arguments -torch_compile_arguments = [ - "config.dce = True", - "config.memory_planning = True", - "config.memory_pool = 'combined'", - "config.coordinate_descent_tuning = True", - "config.max_autotune_gemm = False", # GEMM is unnecessary - "config.autotune_multi_device = False", - "config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster - "config.aggressive_fusion = False", # Careful changes results! - "config.cuda.enable_cuda_lto = True", - "config.cuda.use_fast_math = True", - "config.cuda.compile_opt_level = '-O2'", -] -# Torch dynamo arguments -torch_dynamo_arguments = [ - "config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256 - "config.suppress_errors = True", # Supress errors for now - "config.do_not_emit_runtime_asserts = True", - "config.cache_size_limit = 1024", # Flex Attention - "config.inline_inbuilt_nn_modules = True", # Torch 2.5 Regional recompilation -] -# import torch._inductor.config as config -# for _try_compile_argument in torch_compile_arguments: -# try: exec(_try_compile_argument) -# except: pass -# pass -# import torch._dynamo.config as config -# for _try_dynamo_argument in torch_dynamo_arguments: -# try: exec(_try_dynamo_argument) -# except: pass -# pass torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, - "trace.enabled" : False, # Output Triton kernel outputs! + "trace.enabled" : False, "triton.cudagraphs" : False, } + +import accelerate +def torch_compile_kwargs(*args, **kwargs): + print("Unsloth: Enabled auto compiling") + return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options} +pass + +accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs +accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs +accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs +del accelerate + # ============================================= def prepare_model_for_kbit_training( @@ -499,22 +484,8 @@ def make_inputs_require_grad(module, input, output): pass pass -# Also disable compiling on bitsandbytes -def patch_compiling_bitsandbytes(): - # import peft.tuners.lora.bnb - # peft.tuners.lora.bnb.Linear4bit.forward = \ - # torch._disable_dynamo(peft.tuners.lora.bnb.Linear4bit.forward) - # peft.tuners.lora.bnb.Linear8bitLt.forward = \ - # torch._disable_dynamo(peft.tuners.lora.bnb.Linear8bitLt.forward) - # return - import bitsandbytes.nn.modules - bitsandbytes.nn.modules.Linear4bit.forward = \ - torch._disable_dynamo(bitsandbytes.nn.modules.Linear4bit.forward) - return -pass # ============================================= - import psutil def _get_statistics(statistics = None, force_download = True): # We log some basic stats about which environment is being used. From 82db087cab3918c35290c4001cb61b4265ca54fc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 14:09:47 -0800 Subject: [PATCH 073/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 5236e51985..f5a015073c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -20,9 +20,8 @@ from packaging.version import Version from unsloth_zoo.loss_utils import ( - causal_loss_function, - transformers_losses_patcher, - patch_loss_function, + patch_loss_functions, + post_patch_loss_function, ) From cf682022526762aaa1c887ac6c8055b448cf2f2b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 14:15:08 -0800 Subject: [PATCH 074/209] CE Loss --- unsloth/kernels/__init__.py | 3 +-- unsloth/kernels/cross_entropy_loss.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index df7589719d..78e70a65b5 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -14,8 +14,7 @@ from .cross_entropy_loss import ( fast_cross_entropy_loss, - patch_losses, - patch_loss_function, + post_patch_loss_function, ) from .rms_layernorm import ( fast_rms_layernorm, diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f5a015073c..92e64bfa9b 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -393,4 +393,4 @@ def fast_cross_entropy_loss( pass # Patch CE Losses in transformers -patch_losses = transformers_losses_patcher(causal_loss_function(fast_cross_entropy_loss)) +patch_loss_functions(fast_cross_entropy_loss) From 63a18286d7ccf2f6374daf12b68cd0e07b8871dd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:01:29 -0800 Subject: [PATCH 075/209] Update _utils.py --- unsloth/models/_utils.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 64cb222e61..61d660bce4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -63,6 +63,7 @@ patch_compiling_bitsandbytes, patch_layernorm, patch_torch_compile, + patch_regional_compilation, ) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, @@ -578,27 +579,6 @@ def get_statistics(): pass -# ============================================= -# Regional torch 2.5 Recompilation - weirdly very slow?? -def patch_regional_compilation(): - if torch.nn.ModuleList.__name__ == "UnslothModuleList": return - # Only works for torch 2.5 - if Version(torch.__version__) < Version("2.5.0"): return - - old_module_list = torch.nn.ModuleList - - def UnslothModuleList(*args, **kwargs): - if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: - args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] - return old_module_list(*args, **kwargs) - pass - UnslothModuleList.__doc__ = old_module_list.__doc__ - - torch.nn.ModuleList = UnslothModuleList - return -pass - - # ============================================= # Fixes Bitsandbytes to remove missing warnings from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod From 3f0e56fc2313518f8bc348bc00d10748dd4b5162 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:02:38 -0800 Subject: [PATCH 076/209] Update _utils.py --- unsloth/models/_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 61d660bce4..f49c9db1ff 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -46,6 +46,8 @@ "patch_gradient_accumulation_fix", "patch_compiling_bitsandbytes", "patch_regional_compilation", + "patch_layernorm", + "patch_torch_compile", ] import torch From 1190ed45b3608b914d9de2b318118ddeb15c4b39 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:15:58 -0800 Subject: [PATCH 077/209] CE Loss --- unsloth/kernels/__init__.py | 1 + unsloth/kernels/cross_entropy_loss.py | 6 ++++-- unsloth/models/_utils.py | 1 + 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 78e70a65b5..3b31f49999 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -15,6 +15,7 @@ from .cross_entropy_loss import ( fast_cross_entropy_loss, post_patch_loss_function, + patch_loss_functions, ) from .rms_layernorm import ( fast_rms_layernorm, diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 92e64bfa9b..41bce690e1 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -20,7 +20,7 @@ from packaging.version import Version from unsloth_zoo.loss_utils import ( - patch_loss_functions, + patch_loss_functions as _patch_loss_functions, post_patch_loss_function, ) @@ -393,4 +393,6 @@ def fast_cross_entropy_loss( pass # Patch CE Losses in transformers -patch_loss_functions(fast_cross_entropy_loss) +def patch_loss_functions(): + _patch_loss_functions(fast_cross_entropy_loss) +pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f49c9db1ff..45b9569c0d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -48,6 +48,7 @@ "patch_regional_compilation", "patch_layernorm", "patch_torch_compile", + "patch_loss_functions", ] import torch From 607ac343518666674903247e0d5c2a04c0d4ad3b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:21:03 -0800 Subject: [PATCH 078/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 45b9569c0d..5d2bb3a523 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.7" +__version__ = "2024.11.1" __all__ = [ "prepare_model_for_kbit_training", From 32eac0b6c565d85aba376bd4460e20937fcb58e8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 15:25:18 -0800 Subject: [PATCH 079/209] Update _utils.py --- unsloth/models/_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5d2bb3a523..0fb73c2237 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -48,7 +48,6 @@ "patch_regional_compilation", "patch_layernorm", "patch_torch_compile", - "patch_loss_functions", ] import torch From 5b6d401650752534e890425adb7997644b2357c2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:05:33 -0800 Subject: [PATCH 080/209] Layernorm --- unsloth/kernels/__init__.py | 1 - unsloth/kernels/layernorm.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 3b31f49999..82e7641693 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -25,7 +25,6 @@ from .layernorm import ( fast_layernorm, patch_layernorm, - unpatch_layernorm, ) from .rope_embedding import fast_rope_embedding, inplace_rope_embedding from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel diff --git a/unsloth/kernels/layernorm.py b/unsloth/kernels/layernorm.py index 48d1a65ec7..a5f7926e2e 100644 --- a/unsloth/kernels/layernorm.py +++ b/unsloth/kernels/layernorm.py @@ -17,7 +17,9 @@ import triton.language as tl import torch from .utils import calculate_settings -from unsloth_zoo import patch_layernorm +from unsloth_zoo.patching_utils import ( + patch_layernorm, +) @triton.jit From 3d19a71cf96b09b56390a6bb0a2af55d573a43c4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:09:37 -0800 Subject: [PATCH 081/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 0fb73c2237..82bd4c979d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -364,7 +364,7 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile() +patch_torch_compile(debug = False, O3 = False) torch_compile_options = { "epilogue_fusion" : True, From 76da5117d3358946748c2287c6212b68a26afa9f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:15:27 -0800 Subject: [PATCH 082/209] Update _utils.py --- unsloth/models/_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 82bd4c979d..8f0fc30568 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -59,7 +59,7 @@ from packaging.version import Version from unsloth_zoo.tokenizer_utils import ( - patch_tokenizer, + patch_tokenizer as _patch_tokenizer, ) from unsloth_zoo.patching_utils import ( patch_compiling_bitsandbytes, @@ -983,3 +983,11 @@ def patch_gradient_accumulation_fix(Trainer): exec(function, globals()) Trainer.training_step = _unsloth_training_step pass + + +def patch_tokenizer(model, tokenizer): + model, tokenizer = _patch_tokenizer(model, tokenizer) + if model is not None: + model.config.update({"unsloth_version" : __version__}) + return model, tokenizer +pass From 013ebaa8769757c22ef1426b66922058983532af Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:49:08 -0800 Subject: [PATCH 083/209] Post patch --- unsloth/models/_utils.py | 1 + unsloth/models/gemma.py | 52 +--------------- unsloth/models/gemma2.py | 52 +--------------- unsloth/models/llama.py | 129 +-------------------------------------- 4 files changed, 6 insertions(+), 228 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 8f0fc30568..ba6392eab8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -66,6 +66,7 @@ patch_layernorm, patch_torch_compile, patch_regional_compilation, + patch_model_and_tokenizer, ) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 1ec116b2ea..1d9a0c1334 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -340,56 +340,8 @@ def pre_patch(): @staticmethod def post_patch(model, tokenizer): - # Torch.compile fails on embedding matrix?? - # Workaround randomnly fixes it for torch versions < 2.2 - model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) - model.config.update({"unsloth_version" : __version__}) - - # We also do this for the lm_head - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.lm_head.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - - # Gemma has tied weights! This means lm_head == embed_tokens - if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.model.embed_tokens.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - pass - - # Also patch all dtypes - BnB seems to not allocate the correct type? - # BnB default dtype seems to be float16! - correct_dtype = lm_head.weight.dtype - - for name, module in model.named_modules(): - if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): - weight = module.weight - quant_state = weight.quant_state - - if type(quant_state) is list: - # BnB seems to have float16 as default! - module.weight.quant_state[2] = correct_dtype # Cast to correct dtype - else: - # https://github.com/TimDettmers/bitsandbytes/pull/763/files - quant_state.dtype = correct_dtype - pass - pass - # Downcast RoPE embedding to correct data type - # RoPE must be done in float32 for Gemma - # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \ - # and (module.cos_cached.dtype != correct_dtype): - - # module.cos_cached = module.cos_cached.to(correct_dtype) - # module.sin_cached = module.sin_cached.to(correct_dtype) - # pass - # pass - pass + # Gemma does not downcast RoPE + model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False) # Add 1 to weight # return output * (1 + self.weight) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 54d8f628cb..4eb9d64313 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -491,56 +491,8 @@ def pre_patch(): @staticmethod def post_patch(model, tokenizer): - # Torch.compile fails on embedding matrix?? - # Workaround randomnly fixes it for torch versions < 2.2 - model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) - model.config.update({"unsloth_version" : __version__}) - - # We also do this for the lm_head - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.lm_head.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - - # Gemma has tied weights! This means lm_head == embed_tokens - if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - lm_head.weight = model.model.embed_tokens.weight - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - model.lm_head = lm_head - pass - - # Also patch all dtypes - BnB seems to not allocate the correct type? - # BnB default dtype seems to be float16! - correct_dtype = lm_head.weight.dtype - - for name, module in model.named_modules(): - if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): - weight = module.weight - quant_state = weight.quant_state - - if type(quant_state) is list: - # BnB seems to have float16 as default! - module.weight.quant_state[2] = correct_dtype # Cast to correct dtype - else: - # https://github.com/TimDettmers/bitsandbytes/pull/763/files - quant_state.dtype = correct_dtype - pass - pass - # Downcast RoPE embedding to correct data type - # RoPE must be done in float32 for Gemma - # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \ - # and (module.cos_cached.dtype != correct_dtype): - - # module.cos_cached = module.cos_cached.to(correct_dtype) - # module.sin_cached = module.sin_cached.to(correct_dtype) - # pass - # pass - pass + # Gemma does not downcast RoPE + model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = False) # Add 1 to weight # return output * (1 + self.weight) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8c4b7fd253..4e83e69d60 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -57,8 +57,6 @@ from transformers import set_seed as transformers_set_seed from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model from peft import PeftModelForCausalLM -from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit -from peft.tuners.lora import Linear4bit as Peft_Linear4bit from ..save import patch_saving_functions import re, os, inspect, math, sys try: @@ -1798,30 +1796,6 @@ def from_pretrained( internal_model = internal_model.model pass internal_model._saved_temp_tokenizer = tokenizer - - # Also fix torch_dtype - internal_model = model - while hasattr(internal_model, "model"): - if hasattr(internal_model, "config"): - if internal_model.config.torch_dtype == "float32": - internal_model.config.torch_dtype = torch.float32 - elif internal_model.config.torch_dtype == "bfloat16": - internal_model.config.torch_dtype = torch.bfloat16 - elif internal_model.config.torch_dtype == "float16": - internal_model.config.torch_dtype = torch.float16 - pass - pass - internal_model = internal_model.model - pass - if hasattr(internal_model, "config"): - if internal_model.config.torch_dtype == "float32": - internal_model.config.torch_dtype = torch.float32 - elif internal_model.config.torch_dtype == "bfloat16": - internal_model.config.torch_dtype = torch.bfloat16 - elif internal_model.config.torch_dtype == "float16": - internal_model.config.torch_dtype = torch.float16 - pass - pass return model, tokenizer pass @@ -1829,108 +1803,7 @@ def from_pretrained( @staticmethod def post_patch(model, tokenizer): - # Torch.compile fails on embedding matrix?? - try: old_input_embedding = model.get_input_embeddings ().weight - except: return model, tokenizer - - # Maybe not all models have a lm_head? - try: old_output_embedding = model.get_output_embeddings().weight - except: old_output_embedding = torch.zeros(0) - - # Check for tied weights as well - is_tied = (old_input_embedding.data_ptr() == old_output_embedding.data_ptr()) \ - or (model.config.tie_word_embeddings) - - # Check pad token's id -> we need to expand the embedding - if len(tokenizer) > old_input_embedding.shape[0]: - # Workaround randomnly fixes it for torch versions < 2. - requires_grad = old_input_embedding.requires_grad - old_input_embedding.requires_grad_(False) - old_input_embedding.resize_(len(tokenizer), old_input_embedding.shape[1]) - old_input_embedding.requires_grad_(requires_grad) - - # Fix up all vocab sizes - current_model = model - while hasattr(current_model, "model") and hasattr(current_model, "config"): - if hasattr(current_model.config, "vocab_size"): - current_model.config.update({"vocab_size" : len(tokenizer)}) - current_model = current_model.model - if hasattr(current_model, "model") and hasattr(current_model, "config"): - if hasattr(current_model.config, "vocab_size"): - current_model.config.update({"vocab_size" : len(tokenizer)}) - pass - pass - - model.set_input_embeddings( - torch.nn.Embedding.from_pretrained( - old_input_embedding, - padding_idx = getattr(model.config, "pad_token_id", None), - ) - ) - model.config.update({"unsloth_version" : __version__}) - - # We also do this for the lm_head - if old_output_embedding.numel() != 0: - - requires_grad = old_output_embedding.requires_grad - lm_head = torch.nn.Linear(1, 1, bias = None) - del lm_head.weight - - lm_head.weight = old_output_embedding if not is_tied else old_input_embedding - lm_head.in_features = lm_head.weight.shape[1] - lm_head.out_features = lm_head.weight.shape[0] - - lm_head.weight.requires_grad_(requires_grad) - model.set_output_embeddings(lm_head) - if hasattr(model, "lm_head"): model.lm_head = lm_head - - correct_dtype = lm_head.weight.dtype - else: - correct_dtype = old_input_embedding.dtype - pass - - # Must tie lm_head and embed_tokens if they are tied! - # Otherwise error will occur on saving models ie use save_model - if is_tied: model.tie_weights() - - # Also patch all dtypes - BnB seems to not allocate the correct type? - # BnB default dtype seems to be float16! - for name, module in model.named_modules(): - if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): - weight = module.weight - quant_state = weight.quant_state - - if type(quant_state) is list: - # BnB seems to have float16 as default! - module.weight.quant_state[2] = correct_dtype # Cast to correct dtype - else: - # https://github.com/TimDettmers/bitsandbytes/pull/763/files - quant_state.dtype = correct_dtype - pass - pass - # Downcast RoPE embedding to correct data type - if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")): - - if hasattr(module, "cos_cached") and \ - (module.cos_cached.dtype != correct_dtype): - - module.cos_cached = module.cos_cached.to(correct_dtype) - module.sin_cached = module.sin_cached.to(correct_dtype) - - elif hasattr(module, "short_cos_cached") and \ - (module.short_cos_cached.dtype != correct_dtype): - - module.short_cos_cached = module.short_cos_cached.to(correct_dtype) - module.short_sin_cached = module.short_sin_cached.to(correct_dtype) - pass - pass - pass - - # Clear deleted GPU items - for _ in range(3): - gc.collect() - torch.cuda.empty_cache() - return model, tokenizer + model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = True) pass From 608916a83116f6968c4cfa1e388fc0170945945b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 17:56:03 -0800 Subject: [PATCH 084/209] Update _utils.py --- unsloth/models/_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ba6392eab8..091dbaee29 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -48,6 +48,7 @@ "patch_regional_compilation", "patch_layernorm", "patch_torch_compile", + "patch_model_and_tokenizer", ] import torch From 19836e38bbc4b8367162161d902ff43a2b116db5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 18:08:22 -0800 Subject: [PATCH 085/209] Update llama.py --- unsloth/models/llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 4e83e69d60..3c4d8f3b38 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1804,6 +1804,7 @@ def from_pretrained( @staticmethod def post_patch(model, tokenizer): model, tokenizer = patch_model_and_tokenizer(model, tokenizer, downcast_rope = True) + return model, tokenizer pass From 01640876043b3f4de6382b39b9e9d986a73e1a5b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 18:13:11 -0800 Subject: [PATCH 086/209] Update _utils.py --- unsloth/models/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 091dbaee29..837b4849f1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -354,6 +354,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings +UNSLOTH_COMPILE_DEBUG = True # Just remove max_autotune_gemm warning import functools @@ -366,13 +367,13 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile(debug = False, O3 = False) +patch_torch_compile(debug = UNSLOTH_COMPILE_DEBUG, O3 = False) torch_compile_options = { "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, - "trace.enabled" : False, + "trace.enabled" : UNSLOTH_COMPILE_DEBUG, "triton.cudagraphs" : False, } From 205f7ad7d923ca1bd62fce1094cbeacb3effcb40 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 18:33:02 -0800 Subject: [PATCH 087/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 41bce690e1..a8af945221 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -279,9 +279,6 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) - DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) - if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) @@ -294,9 +291,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): labels, VOCAB_SIZE = vocab_size, BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, + DO_SOFTCAPPING = logit_softcapping != 0, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, + DO_LOGIT_SCALING = logit_scaling != 0, LOGIT_SCALE = logit_scaling, num_warps = num_warps, ) @@ -312,9 +309,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): VOCAB_SIZE = vocab_size, N_CHUNKS = n_chunks, BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, + DO_SOFTCAPPING = logit_softcapping != 0, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, + DO_LOGIT_SCALING = logit_scaling != 0, LOGIT_SCALE = logit_scaling, num_warps = 32, ) @@ -326,9 +323,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): pass ctx.save_for_backward(logits, logsumexp, labels) - ctx.DO_SOFTCAPPING = DO_SOFTCAPPING + ctx.DO_SOFTCAPPING = logit_softcapping != 0 ctx.logit_softcapping = logit_softcapping - ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING + ctx.DO_LOGIT_SCALING = logit_scaling != 0 ctx.logit_scaling = logit_scaling return losses pass From 2f1f393dd7f50252450756c5bbe6bac5a40fb3ec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 18:35:43 -0800 Subject: [PATCH 088/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index a8af945221..04a2e1861c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,10 +25,10 @@ ) -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _cross_entropy_forward( logits_ptr , @@ -98,10 +98,10 @@ def _cross_entropy_forward( pass -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _chunked_cross_entropy_forward( logits_ptr , @@ -181,10 +181,10 @@ def _chunked_cross_entropy_forward( pass -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], - "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], +# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], +# }) @triton.jit def _cross_entropy_backward( logits_ptr , From 05b8f663ef56c21608e304df3444e46bf6b77517 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 19:57:26 -0800 Subject: [PATCH 089/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 04a2e1861c..d33cb78409 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -279,6 +279,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) + DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) + if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) @@ -291,9 +294,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): labels, VOCAB_SIZE = vocab_size, BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = logit_softcapping != 0, + DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = logit_scaling != 0, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, LOGIT_SCALE = logit_scaling, num_warps = num_warps, ) @@ -309,9 +312,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): VOCAB_SIZE = vocab_size, N_CHUNKS = n_chunks, BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = logit_softcapping != 0, + DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = logit_scaling != 0, + DO_LOGIT_SCALING = DO_LOGIT_SCALING, LOGIT_SCALE = logit_scaling, num_warps = 32, ) @@ -323,9 +326,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): pass ctx.save_for_backward(logits, logsumexp, labels) - ctx.DO_SOFTCAPPING = logit_softcapping != 0 + ctx.DO_SOFTCAPPING = DO_SOFTCAPPING ctx.logit_softcapping = logit_softcapping - ctx.DO_LOGIT_SCALING = logit_scaling != 0 + ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING ctx.logit_scaling = logit_scaling return losses pass From 8d205c0e0bb6d6913eded73081d0d7f62f732888 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 19:59:59 -0800 Subject: [PATCH 090/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d33cb78409..70b0f116d5 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -279,9 +279,6 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) - DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) - if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) @@ -294,9 +291,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): labels, VOCAB_SIZE = vocab_size, BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, + DO_SOFTCAPPING = bool(logit_softcapping != 0), SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, + DO_LOGIT_SCALING = bool(logit_scaling != 0), LOGIT_SCALE = logit_scaling, num_warps = num_warps, ) @@ -312,9 +309,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): VOCAB_SIZE = vocab_size, N_CHUNKS = n_chunks, BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = DO_SOFTCAPPING, + DO_SOFTCAPPING = bool(logit_softcapping != 0), SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = DO_LOGIT_SCALING, + DO_LOGIT_SCALING = bool(logit_scaling != 0), LOGIT_SCALE = logit_scaling, num_warps = 32, ) @@ -326,9 +323,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): pass ctx.save_for_backward(logits, logsumexp, labels) - ctx.DO_SOFTCAPPING = DO_SOFTCAPPING + ctx.DO_SOFTCAPPING = bool(logit_softcapping != 0) ctx.logit_softcapping = logit_softcapping - ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING + ctx.DO_LOGIT_SCALING = bool(logit_scaling != 0) ctx.logit_scaling = logit_scaling return losses pass From a1e9e135cd34ab2a4f8f3eed95ecc2b07dc8b440 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 20:03:03 -0800 Subject: [PATCH 091/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 70b0f116d5..347f9eb9da 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -346,9 +346,9 @@ def backward(ctx, dlosses): labels, VOCAB_SIZE = vocab_size, BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, + DO_SOFTCAPPING = bool(ctx.DO_SOFTCAPPING), SOFTCAP = ctx.logit_softcapping, - DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, + DO_LOGIT_SCALING = bool(ctx.DO_LOGIT_SCALING), LOGIT_SCALE = ctx.logit_scaling, num_warps = 8, ) From 94655f8ddb7b623eeaad013bfc0da12c29001daa Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 20:49:59 -0800 Subject: [PATCH 092/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 44 +++++++++++++-------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 347f9eb9da..9c3c9442c9 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,16 +32,16 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + VOCAB_SIZE , + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -105,17 +105,17 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - N_CHUNKS : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + VOCAB_SIZE , + N_CHUNKS , + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -188,17 +188,17 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), - DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + VOCAB_SIZE , + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 085f9988bab2f7fd36809d401f19d02d68185031 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 20:58:31 -0800 Subject: [PATCH 093/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 43 ++++++++++++++------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9c3c9442c9..35fbbed730 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -37,7 +37,7 @@ def _cross_entropy_forward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr , DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -69,7 +69,7 @@ def _cross_entropy_forward( logsumexp_ptr += row_idx labels_ptr += row_idx - col_offsets = tl.arange(0, BLOCK_SIZE) + col_offsets = tl.arange(0, BLOCK_SIZE : tl.constexpr) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) @@ -111,7 +111,7 @@ def _chunked_cross_entropy_forward( labels_ptr , VOCAB_SIZE , N_CHUNKS , - BLOCK_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr , DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -148,7 +148,7 @@ def _chunked_cross_entropy_forward( logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx - col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + col_offsets = chunk_idx*BLOCK_SIZE : tl.constexpr + tl.arange(0, BLOCK_SIZE : tl.constexpr) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) @@ -194,7 +194,7 @@ def _cross_entropy_backward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr , DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -220,7 +220,7 @@ def _cross_entropy_backward( logits_ptr += row_idx * logits_row_stride dloss_ptr += row_idx * dloss_row_stride - col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + col_offsets = block_idx*BLOCK_SIZE : tl.constexpr + tl.arange(0, BLOCK_SIZE : tl.constexpr) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) @@ -279,9 +279,12 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): n_chunks = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) + DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) + if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral - BLOCK_SIZE, num_warps = calculate_settings(vocab_size) + BLOCK_SIZE : tl.constexpr, num_warps = calculate_settings(vocab_size) logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") _cross_entropy_forward[(n_rows,)]( @@ -290,10 +293,10 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): logsumexp, labels, VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = bool(logit_softcapping != 0), + BLOCK_SIZE : tl.constexpr = BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = bool(logit_scaling != 0), + DO_LOGIT_SCALING = DO_LOGIT_SCALING, LOGIT_SCALE = logit_scaling, num_warps = num_warps, ) @@ -308,10 +311,10 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): labels, VOCAB_SIZE = vocab_size, N_CHUNKS = n_chunks, - BLOCK_SIZE = MAX_FUSED_SIZE, - DO_SOFTCAPPING = bool(logit_softcapping != 0), + BLOCK_SIZE : tl.constexpr = MAX_FUSED_SIZE, + DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, - DO_LOGIT_SCALING = bool(logit_scaling != 0), + DO_LOGIT_SCALING = DO_LOGIT_SCALING, LOGIT_SCALE = logit_scaling, num_warps = 32, ) @@ -323,9 +326,9 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): pass ctx.save_for_backward(logits, logsumexp, labels) - ctx.DO_SOFTCAPPING = bool(logit_softcapping != 0) + ctx.DO_SOFTCAPPING = DO_SOFTCAPPING ctx.logit_softcapping = logit_softcapping - ctx.DO_LOGIT_SCALING = bool(logit_scaling != 0) + ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING ctx.logit_scaling = logit_scaling return losses pass @@ -335,8 +338,8 @@ def backward(ctx, dlosses): logits, logsumexp, labels = ctx.saved_tensors n_rows, vocab_size = logits.shape - BLOCK_SIZE = 4096 - div, mod = divmod(vocab_size, BLOCK_SIZE) + BLOCK_SIZE : tl.constexpr = 4096 + div, mod = divmod(vocab_size, BLOCK_SIZE : tl.constexpr) n_blocks = div + (mod != 0) _cross_entropy_backward[(n_rows, n_blocks,)]( @@ -345,10 +348,10 @@ def backward(ctx, dlosses): logsumexp, labels, VOCAB_SIZE = vocab_size, - BLOCK_SIZE = BLOCK_SIZE, - DO_SOFTCAPPING = bool(ctx.DO_SOFTCAPPING), + BLOCK_SIZE : tl.constexpr = BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, SOFTCAP = ctx.logit_softcapping, - DO_LOGIT_SCALING = bool(ctx.DO_LOGIT_SCALING), + DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, LOGIT_SCALE = ctx.logit_scaling, num_warps = 8, ) From c796fd9479550cfa1204a9b410014be56419c2da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 20:58:59 -0800 Subject: [PATCH 094/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 35fbbed730..61a015d9ba 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -37,7 +37,7 @@ def _cross_entropy_forward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -69,7 +69,7 @@ def _cross_entropy_forward( logsumexp_ptr += row_idx labels_ptr += row_idx - col_offsets = tl.arange(0, BLOCK_SIZE : tl.constexpr) + col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) @@ -111,7 +111,7 @@ def _chunked_cross_entropy_forward( labels_ptr , VOCAB_SIZE , N_CHUNKS , - BLOCK_SIZE : tl.constexpr , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -148,7 +148,7 @@ def _chunked_cross_entropy_forward( logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx - col_offsets = chunk_idx*BLOCK_SIZE : tl.constexpr + tl.arange(0, BLOCK_SIZE : tl.constexpr) + col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) @@ -194,7 +194,7 @@ def _cross_entropy_backward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr , + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -220,7 +220,7 @@ def _cross_entropy_backward( logits_ptr += row_idx * logits_row_stride dloss_ptr += row_idx * dloss_row_stride - col_offsets = block_idx*BLOCK_SIZE : tl.constexpr + tl.arange(0, BLOCK_SIZE : tl.constexpr) + col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr + row_idx).to(tl.int32) @@ -284,7 +284,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral - BLOCK_SIZE : tl.constexpr, num_warps = calculate_settings(vocab_size) + BLOCK_SIZE, num_warps = calculate_settings(vocab_size) logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") _cross_entropy_forward[(n_rows,)]( @@ -293,7 +293,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): logsumexp, labels, VOCAB_SIZE = vocab_size, - BLOCK_SIZE : tl.constexpr = BLOCK_SIZE : tl.constexpr, + BLOCK_SIZE = BLOCK_SIZE, DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, DO_LOGIT_SCALING = DO_LOGIT_SCALING, @@ -311,7 +311,7 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): labels, VOCAB_SIZE = vocab_size, N_CHUNKS = n_chunks, - BLOCK_SIZE : tl.constexpr = MAX_FUSED_SIZE, + BLOCK_SIZE = MAX_FUSED_SIZE, DO_SOFTCAPPING = DO_SOFTCAPPING, SOFTCAP = logit_softcapping, DO_LOGIT_SCALING = DO_LOGIT_SCALING, @@ -338,8 +338,8 @@ def backward(ctx, dlosses): logits, logsumexp, labels = ctx.saved_tensors n_rows, vocab_size = logits.shape - BLOCK_SIZE : tl.constexpr = 4096 - div, mod = divmod(vocab_size, BLOCK_SIZE : tl.constexpr) + BLOCK_SIZE = 4096 + div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks = div + (mod != 0) _cross_entropy_backward[(n_rows, n_blocks,)]( @@ -348,7 +348,7 @@ def backward(ctx, dlosses): logsumexp, labels, VOCAB_SIZE = vocab_size, - BLOCK_SIZE : tl.constexpr = BLOCK_SIZE : tl.constexpr, + BLOCK_SIZE = BLOCK_SIZE, DO_SOFTCAPPING = ctx.DO_SOFTCAPPING, SOFTCAP = ctx.logit_softcapping, DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING, From e943d77dbf29f76a71f4de172e935c31d7f1f369 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:03:24 -0800 Subject: [PATCH 095/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 32 +++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 61a015d9ba..efe18f5c14 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,16 +32,16 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING , - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -105,17 +105,17 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + N_CHUNKS : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING , - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ 256K vocab divided in 4 chunks @@ -188,17 +188,17 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, + VOCAB_SIZE : tl.constexpr(tl.int32), + BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING , - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 16a7df63e76eda7e3f46d5fa3eb99ed35acd46e6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:07:23 -0800 Subject: [PATCH 096/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 32 +++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index efe18f5c14..17168b230a 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,16 +32,16 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride : tl.constexpr, loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP : tl.constexpr, DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -105,17 +105,17 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride : tl.constexpr, loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - N_CHUNKS : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP : tl.constexpr, DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -188,17 +188,17 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride : tl.constexpr, dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP : tl.constexpr, DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From f65b064216555e37f3da18ef904169c237e744c3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:09:34 -0800 Subject: [PATCH 097/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 17168b230a..15a148b00e 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,12 +36,12 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -109,13 +109,13 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, + VOCAB_SIZE , + N_CHUNKS , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -193,12 +193,12 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 1ff49b8984ec5226ea9d31630a6e6420ad7bc115 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:11:34 -0800 Subject: [PATCH 098/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 15a148b00e..bcdef31997 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,11 +32,11 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr, + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -105,12 +105,12 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr, + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -188,12 +188,12 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride : tl.constexpr, + logits_row_stride , dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , From 080e558ddf98ea90a79c4d5c005b33cd0853f16e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:16:19 -0800 Subject: [PATCH 099/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index bcdef31997..4ac5a0c8b8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,7 +36,7 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -109,8 +109,8 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, + VOCAB_SIZE , + N_CHUNKS , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -193,7 +193,7 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -272,16 +272,20 @@ def _cross_entropy_backward( class Fast_CrossEntropyLoss(torch.autograd.Function): @staticmethod - def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): + def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0): + n_rows : int + vocab_size : int n_rows, vocab_size = logits.shape div, mod = divmod(vocab_size, MAX_FUSED_SIZE) - n_chunks = div + (mod != 0) + n_chunks : int = div + (mod != 0) losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") DO_SOFTCAPPING : bool = bool(logit_softcapping != 0) DO_LOGIT_SCALING : bool = bool(logit_scaling != 0) + BLOCK_SIZE : int + num_warps : int if n_chunks == 1: # For small vocabs <= 65336 like Llama, Mistral BLOCK_SIZE, num_warps = calculate_settings(vocab_size) @@ -336,11 +340,13 @@ def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0): @staticmethod def backward(ctx, dlosses): logits, logsumexp, labels = ctx.saved_tensors + n_rows : int + vocab_size : int n_rows, vocab_size = logits.shape - BLOCK_SIZE = 4096 + BLOCK_SIZE : int = 4096 div, mod = divmod(vocab_size, BLOCK_SIZE) - n_blocks = div + (mod != 0) + n_blocks : int = div + (mod != 0) _cross_entropy_backward[(n_rows, n_blocks,)]( logits, logits.stride(0), From f6d50c78ac30e7236386a685a7a151a882cb3a50 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:18:26 -0800 Subject: [PATCH 100/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4ac5a0c8b8..3d4a548a1f 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -37,10 +37,10 @@ def _cross_entropy_forward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , + BLOCK_SIZE : tl.constexpr(tl.int32), + DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP , - DO_LOGIT_SCALING , + DO_LOGIT_SCALING : tl.constexpr(tl.int1), LOGIT_SCALE , ): """ @@ -111,10 +111,10 @@ def _chunked_cross_entropy_forward( labels_ptr , VOCAB_SIZE , N_CHUNKS , - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , + BLOCK_SIZE : tl.constexpr(tl.int32), + DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP , - DO_LOGIT_SCALING , + DO_LOGIT_SCALING : tl.constexpr(tl.int1), LOGIT_SCALE , ): """ @@ -194,10 +194,10 @@ def _cross_entropy_backward( logsumexp_ptr , labels_ptr , VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , + BLOCK_SIZE : tl.constexpr(tl.int32), + DO_SOFTCAPPING : tl.constexpr(tl.int1), SOFTCAP , - DO_LOGIT_SCALING , + DO_LOGIT_SCALING : tl.constexpr(tl.int1), LOGIT_SCALE , ): """ From fad420255846915d1fb1ba30cd27eb2a01c5c735 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:19:42 -0800 Subject: [PATCH 101/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 3d4a548a1f..4ebe69d565 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -25,10 +25,10 @@ ) -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), +}) @triton.jit def _cross_entropy_forward( logits_ptr , @@ -98,10 +98,10 @@ def _cross_entropy_forward( pass -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), +}) @triton.jit def _chunked_cross_entropy_forward( logits_ptr , @@ -181,10 +181,10 @@ def _chunked_cross_entropy_forward( pass -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ], -# "DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"], -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), +}) @triton.jit def _cross_entropy_backward( logits_ptr , From 736b16ac6850c55669902334e6e53b79c0f79a7b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:22:08 -0800 Subject: [PATCH 102/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4ebe69d565..d396538e69 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,16 +32,16 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -105,17 +105,17 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , + VOCAB_SIZE : tl.constexpr(tl.int32), + N_CHUNKS : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ 256K vocab divided in 4 chunks @@ -188,17 +188,17 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From eb764169978a6af95c27a34cc2830d61c04cbcec Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:23:28 -0800 Subject: [PATCH 103/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d396538e69..09eb0854e8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -190,7 +190,7 @@ def _cross_entropy_backward( logits_ptr , logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , - dloss_row_stride , + dloss_row_stride : tl.constexpr(tl.int32), logsumexp_ptr , labels_ptr , VOCAB_SIZE : tl.constexpr(tl.int32), From 367e43fe0a81a0cc17a4af050ca6b9cf0fe21c13 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:25:32 -0800 Subject: [PATCH 104/209] typing --- unsloth/kernels/cross_entropy_loss.py | 2 +- unsloth/kernels/rms_layernorm.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 09eb0854e8..d396538e69 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -190,7 +190,7 @@ def _cross_entropy_backward( logits_ptr , logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , - dloss_row_stride : tl.constexpr(tl.int32), + dloss_row_stride , logsumexp_ptr , labels_ptr , VOCAB_SIZE : tl.constexpr(tl.int32), diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 13faf08d6a..c0fb222b8a 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -53,7 +53,7 @@ def _rms_layernorm_forward( pass -@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],}) +@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),}) @triton.jit def _rms_layernorm_backward( dY, dY_row_stride, @@ -130,11 +130,15 @@ def _gemma_rms_layernorm_forward( class Fast_RMS_Layernorm(torch.autograd.Function): @staticmethod - def forward(ctx, X, W, eps, gemma = False): + def forward(ctx, X, W, eps :float, gemma : bool = False): shape = X.shape - dim = shape[-1] + dim : int = shape[-1] X = X.view(-1, dim) + n_rows : int + n_cols : int n_rows, n_cols = X.shape + BLOCK_SIZE : int + num_warps : int BLOCK_SIZE, num_warps = calculate_settings(n_cols) Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") @@ -161,9 +165,11 @@ def forward(ctx, X, W, eps, gemma = False): @staticmethod def backward(ctx, dY): shape = dY.shape - dim = shape[-1] + dim : int = shape[-1] dY = dY.view(-1, dim) X, W, r = ctx.saved_tensors + n_rows : int + n_cols : int n_rows, n_cols = dY.shape dW = X From 993df2043cdaccae639b97350e7991d432ac3fa3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:27:37 -0800 Subject: [PATCH 105/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d396538e69..f92ddbd05d 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,7 +32,7 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , @@ -105,7 +105,7 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , @@ -188,7 +188,7 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , dloss_ptr , dloss_row_stride , logsumexp_ptr , From 8f566b31b66a8071fec2231db3a00a16b7138f86 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:29:19 -0800 Subject: [PATCH 106/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f92ddbd05d..4ebe69d565 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,12 +36,12 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -109,13 +109,13 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - N_CHUNKS : tl.constexpr(tl.int32), + VOCAB_SIZE , + N_CHUNKS , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -193,12 +193,12 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From 22bb46b21b8ec5e824f174104d4d9c10ecd899da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:31:18 -0800 Subject: [PATCH 107/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 4ebe69d565..7939758fda 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -39,9 +39,9 @@ def _cross_entropy_forward( VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -113,9 +113,9 @@ def _chunked_cross_entropy_forward( N_CHUNKS , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ 256K vocab divided in 4 chunks @@ -196,9 +196,9 @@ def _cross_entropy_backward( VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP , + SOFTCAP : tl.constexpr(tl.float32), DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE , + LOGIT_SCALE : tl.constexpr(tl.float32), ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From b5c9f8193e050181328b5d53356862b269009cff Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:33:26 -0800 Subject: [PATCH 108/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 7939758fda..eea6e86065 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,16 +32,16 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -105,17 +105,17 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , + VOCAB_SIZE : tl.constexpr(tl.int32), + N_CHUNKS : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -188,17 +188,17 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride , + logits_row_stride : tl.constexpr(tl.int64), dloss_ptr , dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , + VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING : tl.constexpr(tl.int1), - SOFTCAP : tl.constexpr(tl.float32), + SOFTCAP , DO_LOGIT_SCALING : tl.constexpr(tl.int1), - LOGIT_SCALE : tl.constexpr(tl.float32), + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) From c7b22206e8ec4f15f63e812f70a3090c99f712fd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:35:01 -0800 Subject: [PATCH 109/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index eea6e86065..9365b3f5c8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -32,7 +32,7 @@ @triton.jit def _cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , @@ -105,7 +105,7 @@ def _cross_entropy_forward( @triton.jit def _chunked_cross_entropy_forward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , loss_ptr , logsumexp_ptr , labels_ptr , @@ -188,7 +188,7 @@ def _chunked_cross_entropy_forward( @triton.jit def _cross_entropy_backward( logits_ptr , - logits_row_stride : tl.constexpr(tl.int64), + logits_row_stride , dloss_ptr , dloss_row_stride , logsumexp_ptr , From 2d0ab26c10b970426b4c074fd6e94c82e6561833 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:35:27 -0800 Subject: [PATCH 110/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9365b3f5c8..19b384a3bd 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * tl.int64(logits_row_stride) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -143,7 +143,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * tl.int64(logits_row_stride) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -218,7 +218,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * tl.int64(logits_row_stride) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE From 428f662c12261dc073256efbfd091b56ef628a00 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:38:19 -0800 Subject: [PATCH 111/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 19b384a3bd..8b2204436b 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -38,9 +38,9 @@ def _cross_entropy_forward( labels_ptr , VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), + DO_SOFTCAPPING , SOFTCAP , - DO_LOGIT_SCALING : tl.constexpr(tl.int1), + DO_LOGIT_SCALING , LOGIT_SCALE , ): """ @@ -112,9 +112,9 @@ def _chunked_cross_entropy_forward( VOCAB_SIZE : tl.constexpr(tl.int32), N_CHUNKS : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), + DO_SOFTCAPPING , SOFTCAP , - DO_LOGIT_SCALING : tl.constexpr(tl.int1), + DO_LOGIT_SCALING , LOGIT_SCALE , ): """ @@ -195,9 +195,9 @@ def _cross_entropy_backward( labels_ptr , VOCAB_SIZE : tl.constexpr(tl.int32), BLOCK_SIZE : tl.constexpr(tl.int32), - DO_SOFTCAPPING : tl.constexpr(tl.int1), + DO_SOFTCAPPING , SOFTCAP , - DO_LOGIT_SCALING : tl.constexpr(tl.int1), + DO_LOGIT_SCALING , LOGIT_SCALE , ): """ From 5023ce908670a380e7bfe0fa57d38952039f9d73 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:40:13 -0800 Subject: [PATCH 112/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 8b2204436b..a7ef164df0 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * tl.int64(logits_row_stride) + logits_ptr += row_idx * logits_row_stride loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -143,7 +143,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * tl.int64(logits_row_stride) + logits_ptr += row_idx * logits_row_stride loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -218,7 +218,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * tl.int64(logits_row_stride) + logits_ptr += row_idx * logits_row_stride dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE From 5ca3d4ad92c6dd76b54bf1f990beecd746d6200e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 21:42:13 -0800 Subject: [PATCH 113/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index a7ef164df0..b4780db278 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,7 +36,7 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , SOFTCAP , @@ -109,8 +109,8 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), - N_CHUNKS : tl.constexpr(tl.int32), + VOCAB_SIZE , + N_CHUNKS , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , SOFTCAP , @@ -193,7 +193,7 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr(tl.int32), DO_SOFTCAPPING , SOFTCAP , From 3b32d81f9ed40ed81cf98f5e735a14c43bd3ba66 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 23:53:03 -0800 Subject: [PATCH 114/209] int64 --- unsloth/kernels/cross_entropy_loss.py | 20 ++++++++++---------- unsloth/models/_utils.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index b4780db278..11e582711d 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,8 +36,8 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * logits_row_stride.to(tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -109,9 +109,9 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - N_CHUNKS , - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -143,7 +143,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * logits_row_stride.to(tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -193,8 +193,8 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE , - BLOCK_SIZE : tl.constexpr(tl.int32), + VOCAB_SIZE : tl.constexpr, + BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , DO_LOGIT_SCALING , @@ -218,7 +218,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride + logits_ptr += row_idx * logits_row_stride.to(tl.int64) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 837b4849f1..a260466915 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -92,6 +92,17 @@ # Stop "Special tokens have been added in the vocabulary, ..." import logging logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) + +# Ignore logging messages +class HideLoggingMessage(logging.Filter): + def __init__(self, text): self.text = text + def filter(self, x): return not x.getMessage().startswith(self.text) +pass + +# The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. +import transformers.training_args.logger +transformers.training_args.logger.addFilter(HideLoggingMessage("The speedups")) + # ============================================= # ============================================= @@ -380,7 +391,7 @@ def is_big_gpu(index): import accelerate def torch_compile_kwargs(*args, **kwargs): print("Unsloth: Enabled auto compiling") - return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options} + return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,} pass accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs From 9bae6e21492b95b512d8f76375a4510ed1d7ec1c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sun, 3 Nov 2024 23:55:39 -0800 Subject: [PATCH 115/209] Update _utils.py --- unsloth/models/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a260466915..c309e1c1cd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -100,8 +100,9 @@ def filter(self, x): return not x.getMessage().startswith(self.text) pass # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. -import transformers.training_args.logger -transformers.training_args.logger.addFilter(HideLoggingMessage("The speedups")) +from transformers.training_args import logger as transformers_training_args_logger +transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) +del transformers_training_args_logger # ============================================= From 5123623d3e45a2895769dcb4862812dbd8e6ada3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 00:01:04 -0800 Subject: [PATCH 116/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 11e582711d..91b57fe85d 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -38,10 +38,10 @@ def _cross_entropy_forward( labels_ptr , VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -112,10 +112,10 @@ def _chunked_cross_entropy_forward( VOCAB_SIZE : tl.constexpr, N_CHUNKS : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -143,7 +143,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -195,10 +195,10 @@ def _cross_entropy_backward( labels_ptr , VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING : tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) @@ -218,7 +218,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE From 4b1d9e262216608fcc6db9172230f467656fb785 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 00:03:24 -0800 Subject: [PATCH 117/209] constexpr --- unsloth/kernels/cross_entropy_loss.py | 12 ++++++------ unsloth/models/_utils.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 91b57fe85d..9cf7ddc36d 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -38,9 +38,9 @@ def _cross_entropy_forward( labels_ptr , VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, + DO_SOFTCAPPING , SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING : tl.constexpr, + DO_LOGIT_SCALING , LOGIT_SCALE : tl.constexpr, ): """ @@ -112,9 +112,9 @@ def _chunked_cross_entropy_forward( VOCAB_SIZE : tl.constexpr, N_CHUNKS : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, + DO_SOFTCAPPING , SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING : tl.constexpr, + DO_LOGIT_SCALING , LOGIT_SCALE : tl.constexpr, ): """ @@ -195,9 +195,9 @@ def _cross_entropy_backward( labels_ptr , VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, + DO_SOFTCAPPING , SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING : tl.constexpr, + DO_LOGIT_SCALING , LOGIT_SCALE : tl.constexpr, ): """ diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index c309e1c1cd..dd37d26ae4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -366,8 +366,8 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings -UNSLOTH_COMPILE_DEBUG = True - +UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os.environ +UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) @@ -379,7 +379,7 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile(debug = UNSLOTH_COMPILE_DEBUG, O3 = False) +patch_torch_compile(debug = UNSLOTH_COMPILE_DEBUG, O3 = UNSLOTH_COMPILE_MAXIMUM) torch_compile_options = { "epilogue_fusion" : True, From 7d5111a40c13de36aa675d19fcb9c6c6a9deb5de Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 00:05:51 -0800 Subject: [PATCH 118/209] constexpr --- unsloth/kernels/cross_entropy_loss.py | 12 ++++++------ unsloth/models/_utils.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9cf7ddc36d..d1b8ae8275 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -39,9 +39,9 @@ def _cross_entropy_forward( VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ] @@ -113,9 +113,9 @@ def _chunked_cross_entropy_forward( N_CHUNKS : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -196,9 +196,9 @@ def _cross_entropy_backward( VOCAB_SIZE : tl.constexpr, BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , - SOFTCAP : tl.constexpr, + SOFTCAP , DO_LOGIT_SCALING , - LOGIT_SCALE : tl.constexpr, + LOGIT_SCALE , ): """ CE_i = -y log(P) = y * (log[sum(exp(x))] - x) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index dd37d26ae4..5c099548d6 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -366,7 +366,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings -UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os.environ +UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools From dff5a5250eee1a751fc4d1efb886b9989ee56274 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 00:07:52 -0800 Subject: [PATCH 119/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index d1b8ae8275..abe08d7d06 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -110,7 +110,7 @@ def _chunked_cross_entropy_forward( logsumexp_ptr , labels_ptr , VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, + N_CHUNKS , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , From 969d1bd287e8d3c3f31c2f4ee948dc44ca61913b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 00:32:35 -0800 Subject: [PATCH 120/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index abe08d7d06..f0193c74d8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -36,7 +36,7 @@ def _cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , @@ -109,7 +109,7 @@ def _chunked_cross_entropy_forward( loss_ptr , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , N_CHUNKS , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , @@ -193,7 +193,7 @@ def _cross_entropy_backward( dloss_row_stride , logsumexp_ptr , labels_ptr , - VOCAB_SIZE : tl.constexpr, + VOCAB_SIZE , BLOCK_SIZE : tl.constexpr, DO_SOFTCAPPING , SOFTCAP , From 4b5847ffd53fb4dc33e982b31c97d45ad9b0383f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 01:20:31 -0800 Subject: [PATCH 121/209] Update _utils.py --- unsloth/models/_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5c099548d6..7a8dcedada 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -41,14 +41,17 @@ "torch_amp_custom_bwd", "accelerate_old_send_to_device", "accelerate_new_send_to_device", - "patch_gradient_checkpointing", - "unpatch_gradient_checkpointing", "patch_gradient_accumulation_fix", "patch_compiling_bitsandbytes", "patch_regional_compilation", "patch_layernorm", "patch_torch_compile", "patch_model_and_tokenizer", + + "patch_unsloth_gradient_checkpointing", + "unpatch_unsloth_gradient_checkpointing", + "patch_gradient_checkpointing", + "unpatch_gradient_checkpointing", ] import torch From 766bf1ef658a01cf793deba145278331cbaa689a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 01:23:24 -0800 Subject: [PATCH 122/209] Update _utils.py --- unsloth/models/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7a8dcedada..182fdb1dd4 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -75,6 +75,11 @@ from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, unsloth_offloaded_gradient_checkpoint, + patch_unsloth_gradient_checkpointing, + unpatch_unsloth_gradient_checkpointing, + + Unsloth_Gradient_Checkpointer, + unsloth_gradient_checkpoint, patch_gradient_checkpointing, unpatch_gradient_checkpointing, ) From 646f1b7fcaae90a693707b2bd16519aa7cf13c8a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 17:56:16 -0800 Subject: [PATCH 123/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 182fdb1dd4..7ed9f7360e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -375,7 +375,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From 97f37ace9678df055f0027c6ba31e986550755fe Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 19:45:10 -0800 Subject: [PATCH 124/209] CE --- unsloth/kernels/cross_entropy_loss.py | 2 ++ unsloth/models/_utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f0193c74d8..9cc983d803 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -348,6 +348,8 @@ def backward(ctx, dlosses): div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks : int = div + (mod != 0) + print("111") + _cross_entropy_backward[(n_rows, n_blocks,)]( logits, logits.stride(0), dlosses, dlosses.stride(0), diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7ed9f7360e..cf34767d7b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -375,7 +375,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_MAXIMUM = False #"UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From cc563fac522bf00f6b28eeb9ac704f65a13d708a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 19:48:09 -0800 Subject: [PATCH 125/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 9cc983d803..f0193c74d8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -348,8 +348,6 @@ def backward(ctx, dlosses): div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks : int = div + (mod != 0) - print("111") - _cross_entropy_backward[(n_rows, n_blocks,)]( logits, logits.stride(0), dlosses, dlosses.stride(0), From f643148ea5f2b760a1c9ab1db573e53ac0c7b613 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 21:47:48 -0800 Subject: [PATCH 126/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cf34767d7b..7ed9f7360e 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -375,7 +375,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = False #"UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From f28d7f6f678ea2a241171a5c438b8363cce10919 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:08:49 -0800 Subject: [PATCH 127/209] Update llama.py --- unsloth/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c4d8f3b38..c8bf5f8894 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1028,7 +1028,6 @@ def _CausalLM_fast_forward( pass -@torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, input_ids=None, From d8103e16b1c116e98c01e77c2f14f3eebeb93437 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:10:52 -0800 Subject: [PATCH 128/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 7ed9f7360e..cf34767d7b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -375,7 +375,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_MAXIMUM = False #"UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From b9e1a49d2b01dd8557593c3f950b6ce74c4616d9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 22:38:47 -0800 Subject: [PATCH 129/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index c0fb222b8a..b52afa18c2 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -130,7 +130,7 @@ def _gemma_rms_layernorm_forward( class Fast_RMS_Layernorm(torch.autograd.Function): @staticmethod - def forward(ctx, X, W, eps :float, gemma : bool = False): + def forward(ctx, X, W, eps : float, gemma : bool = False): shape = X.shape dim : int = shape[-1] X = X.view(-1, dim) @@ -144,16 +144,26 @@ def forward(ctx, X, W, eps :float, gemma : bool = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward - fx[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + if not gemma: + _rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + else: + _gemma_rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps From 56af302c45a8f72683b2467d5193a81a57a7f1a3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:01:14 -0800 Subject: [PATCH 130/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index b52afa18c2..2a9fbccd9e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -20,12 +20,17 @@ @triton.jit def _rms_layernorm_forward( - Y, Y_row_stride, - X, X_row_stride, - W, W_row_stride, - r, r_row_stride, - n_cols, eps, - BLOCK_SIZE : tl.constexpr + Y, + Y_row_stride, + X, + X_row_stride, + W, + W_row_stride, + r, + r_row_stride, + n_cols, + eps, + BLOCK_SIZE : tl.constexpr, ): """ Fast RMS Layernorm kernel @@ -150,7 +155,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): X, X.stride(0), W, W.stride(0), r, r.stride(0), - n_cols, eps, + n_cols = int(n_cols), + eps = float(eps), BLOCK_SIZE = BLOCK_SIZE, num_warps = num_warps, ) From a3c84a385119b43bae86dc7c36d0ffd2a3441690 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:04:34 -0800 Subject: [PATCH 131/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 2a9fbccd9e..0dc4be6ba7 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -149,7 +149,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - if not gemma: + if gemma == False: _rms_layernorm_forward[(n_rows,)]( Y, Y.stride(0), X, X.stride(0), From f7d5c565847ef83d4a7888f7c2251661cf3a8803 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:06:13 -0800 Subject: [PATCH 132/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 0dc4be6ba7..4074d3a502 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -138,7 +138,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function): def forward(ctx, X, W, eps : float, gemma : bool = False): shape = X.shape dim : int = shape[-1] - X = X.view(-1, dim) + X : torch.Tensor = X.view(-1, dim) n_rows : int n_cols : int n_rows, n_cols = X.shape From 8496ff69e1d4217e915f5446873424c55403cb8c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:08:05 -0800 Subject: [PATCH 133/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 4074d3a502..7a788f36a7 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -149,27 +149,27 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - if gemma == False: - _rms_layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols = int(n_cols), - eps = float(eps), - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) - else: - _gemma_rms_layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) + # if gemma == False: + _rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols = int(n_cols), + eps = float(eps), + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) + # else: + # _gemma_rms_layernorm_forward[(n_rows,)]( + # Y, Y.stride(0), + # X, X.stride(0), + # W, W.stride(0), + # r, r.stride(0), + # n_cols, eps, + # BLOCK_SIZE = BLOCK_SIZE, + # num_warps = num_warps, + # ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps From 2909eaf08e92f81aa509d958d2c5bb17f2ddc870 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:11:03 -0800 Subject: [PATCH 134/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 44 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 7a788f36a7..b6ffa5fee9 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -138,7 +138,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function): def forward(ctx, X, W, eps : float, gemma : bool = False): shape = X.shape dim : int = shape[-1] - X : torch.Tensor = X.view(-1, dim) + X = X.view(-1, dim) n_rows : int n_cols : int n_rows, n_cols = X.shape @@ -149,27 +149,27 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - # if gemma == False: - _rms_layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols = int(n_cols), - eps = float(eps), - BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, - ) - # else: - # _gemma_rms_layernorm_forward[(n_rows,)]( - # Y, Y.stride(0), - # X, X.stride(0), - # W, W.stride(0), - # r, r.stride(0), - # n_cols, eps, - # BLOCK_SIZE = BLOCK_SIZE, - # num_warps = num_warps, - # ) + if not gemma: + _rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols = int(n_cols), + eps = float(eps), + BLOCK_SIZE = BLOCK_SIZE, + num_warps = 16, + ) + else: + _gemma_rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, + ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps From afc8af69f5a883e0699d9f9abd4c781da5ccef12 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:16:01 -0800 Subject: [PATCH 135/209] Update utils.py --- unsloth/kernels/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index a8c20c75a4..b394d122fd 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -13,7 +13,7 @@ # limitations under the License. import triton -MAX_FUSED_SIZE = 65536 +MAX_FUSED_SIZE : int = 65536 next_power_of_2 = triton.next_power_of_2 # torch.cuda.amp.custom_fwd is deprecated >= 2.4 @@ -40,12 +40,12 @@ pass -def calculate_settings(n): - BLOCK_SIZE = next_power_of_2(n) +def calculate_settings(n : int) -> (int, int,): + BLOCK_SIZE : int = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") - num_warps = 4 + num_warps : int = 4 if BLOCK_SIZE >= 32768: num_warps = 32 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: num_warps = 8 From 2d8d1e1e2da19235ca29641040f95619e690b16a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:19:02 -0800 Subject: [PATCH 136/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index b6ffa5fee9..23b7342c87 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -144,7 +144,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_rows, n_cols = X.shape BLOCK_SIZE : int num_warps : int - BLOCK_SIZE, num_warps = calculate_settings(n_cols) + # BLOCK_SIZE, num_warps = calculate_settings(n_cols) Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") @@ -157,7 +157,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): r, r.stride(0), n_cols = int(n_cols), eps = float(eps), - BLOCK_SIZE = BLOCK_SIZE, + BLOCK_SIZE = 4096, num_warps = 16, ) else: From ecc1ad223c321c8ce009f702166397eba105b008 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:20:32 -0800 Subject: [PATCH 137/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 23b7342c87..821f6b540d 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -171,8 +171,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): num_warps = num_warps, ) ctx.eps = eps - ctx.BLOCK_SIZE = BLOCK_SIZE - ctx.num_warps = num_warps + ctx.BLOCK_SIZE = 4096 + ctx.num_warps = 16 ctx.GEMMA = gemma ctx.save_for_backward(X, W, r) return Y.view(*shape) From ae7cb78e9122b26536e8c25b1c79a2bd642e86da Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:24:33 -0800 Subject: [PATCH 138/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 821f6b540d..53f1efa9f5 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -16,6 +16,7 @@ import triton.language as tl import torch from .utils import calculate_settings +next_power_of_2 = triton.next_power_of_2 @triton.jit @@ -142,9 +143,15 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_rows : int n_cols : int n_rows, n_cols = X.shape - BLOCK_SIZE : int - num_warps : int - # BLOCK_SIZE, num_warps = calculate_settings(n_cols) + BLOCK_SIZE : int = next_power_of_2(n_cols) + MAX_FUSED_SIZE : int = 65536 + if BLOCK_SIZE > MAX_FUSED_SIZE: + raise RuntimeError(f"Cannot launch Triton kernel since n = {n_cols} exceeds "\ + f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") + num_warps : int = 4 + if BLOCK_SIZE >= 32768: num_warps = 32 + elif BLOCK_SIZE >= 8192: num_warps = 16 + elif BLOCK_SIZE >= 2048: num_warps = 8 Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") @@ -157,8 +164,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): r, r.stride(0), n_cols = int(n_cols), eps = float(eps), - BLOCK_SIZE = 4096, - num_warps = 16, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, ) else: _gemma_rms_layernorm_forward[(n_rows,)]( @@ -171,8 +178,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): num_warps = num_warps, ) ctx.eps = eps - ctx.BLOCK_SIZE = 4096 - ctx.num_warps = 16 + ctx.BLOCK_SIZE = BLOCK_SIZE + ctx.num_warps = num_warps ctx.GEMMA = gemma ctx.save_for_backward(X, W, r) return Y.view(*shape) From 22da2662197fd0feb2843eca12036e6027b28670 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:27:50 -0800 Subject: [PATCH 139/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 53f1efa9f5..a7385cb38e 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -15,8 +15,11 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings -next_power_of_2 = triton.next_power_of_2 +from .utils import ( + calculate_settings, + MAX_FUSED_SIZE, + next_power_of_2, +) @triton.jit @@ -143,8 +146,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_rows : int n_cols : int n_rows, n_cols = X.shape - BLOCK_SIZE : int = next_power_of_2(n_cols) - MAX_FUSED_SIZE : int = 65536 + + BLOCK_SIZE : int = n_cols if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError(f"Cannot launch Triton kernel since n = {n_cols} exceeds "\ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") From beb6854e76db7ebe0419735f0b1e58683da3ad6a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:28:57 -0800 Subject: [PATCH 140/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index a7385cb38e..4079921f91 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -15,11 +15,7 @@ import triton import triton.language as tl import torch -from .utils import ( - calculate_settings, - MAX_FUSED_SIZE, - next_power_of_2, -) +from .utils import calculate_settings @triton.jit @@ -147,19 +143,13 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_cols : int n_rows, n_cols = X.shape - BLOCK_SIZE : int = n_cols - if BLOCK_SIZE > MAX_FUSED_SIZE: - raise RuntimeError(f"Cannot launch Triton kernel since n = {n_cols} exceeds "\ - f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.") - num_warps : int = 4 - if BLOCK_SIZE >= 32768: num_warps = 32 - elif BLOCK_SIZE >= 8192: num_warps = 16 - elif BLOCK_SIZE >= 2048: num_warps = 8 - Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + BLOCK_SIZE : int + num_warps : int if not gemma: + BLOCK_SIZE, num_warps = calculate_settings(n_cols) _rms_layernorm_forward[(n_rows,)]( Y, Y.stride(0), X, X.stride(0), From 14c3d2f900aa3cdb15d5c9e85ddd187f9fe532d5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:30:55 -0800 Subject: [PATCH 141/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 4079921f91..3176d4e358 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -142,14 +142,14 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_rows : int n_cols : int n_rows, n_cols = X.shape + BLOCK_SIZE : int + num_warps : int + BLOCK_SIZE, num_warps = calculate_settings(n_cols) Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - BLOCK_SIZE : int - num_warps : int if not gemma: - BLOCK_SIZE, num_warps = calculate_settings(n_cols) _rms_layernorm_forward[(n_rows,)]( Y, Y.stride(0), X, X.stride(0), @@ -158,7 +158,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): n_cols = int(n_cols), eps = float(eps), BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, + num_warps = 16, ) else: _gemma_rms_layernorm_forward[(n_rows,)]( @@ -168,7 +168,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): r, r.stride(0), n_cols, eps, BLOCK_SIZE = BLOCK_SIZE, - num_warps = num_warps, + num_warps = 16, ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE From ef4b079b87cb3b3956b34700d6de78ef81912f9c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:33:08 -0800 Subject: [PATCH 142/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 3176d4e358..585684567c 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -206,9 +206,9 @@ def backward(ctx, dY): pass -def fast_rms_layernorm(layernorm, X, gemma = False): +def fast_rms_layernorm(layernorm, X, gemma : bool = False): W = layernorm.weight - eps = layernorm.variance_epsilon if \ + eps : float = layernorm.variance_epsilon if \ hasattr(layernorm, "variance_epsilon") \ else layernorm.eps out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) From ef684f8d23ba10fd96b6f16cff1681c0421862cd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:34:11 -0800 Subject: [PATCH 143/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 585684567c..c331844e16 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -157,7 +157,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): r, r.stride(0), n_cols = int(n_cols), eps = float(eps), - BLOCK_SIZE = BLOCK_SIZE, + BLOCK_SIZE = triton.next_power_of_2(n_cols), num_warps = 16, ) else: @@ -167,7 +167,7 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): W, W.stride(0), r, r.stride(0), n_cols, eps, - BLOCK_SIZE = BLOCK_SIZE, + BLOCK_SIZE = triton.next_power_of_2(n_cols), num_warps = 16, ) ctx.eps = eps From 3e4c42f79ca041d388303880fef3a75e4728071e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:39:03 -0800 Subject: [PATCH 144/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index c331844e16..ae16b40925 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -146,8 +146,8 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): num_warps : int BLOCK_SIZE, num_warps = calculate_settings(n_cols) - Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") - r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + Y : torch.Tensor = torch.empty(X.shape, dtype = X.dtype, device = "cuda:0") + r : torch.Tensor = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") if not gemma: _rms_layernorm_forward[(n_rows,)]( From 8f825eb2845518a974245e99d80f713a0c97a1ee Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:45:25 -0800 Subject: [PATCH 145/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 46 ++++++++++++++++---------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index ae16b40925..c6aecc6d2f 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -138,7 +138,7 @@ class Fast_RMS_Layernorm(torch.autograd.Function): def forward(ctx, X, W, eps : float, gemma : bool = False): shape = X.shape dim : int = shape[-1] - X = X.view(-1, dim) + # X = X.view(-1, dim) n_rows : int n_cols : int n_rows, n_cols = X.shape @@ -149,33 +149,33 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): Y : torch.Tensor = torch.empty(X.shape, dtype = X.dtype, device = "cuda:0") r : torch.Tensor = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - if not gemma: - _rms_layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols = int(n_cols), - eps = float(eps), - BLOCK_SIZE = triton.next_power_of_2(n_cols), - num_warps = 16, - ) - else: - _gemma_rms_layernorm_forward[(n_rows,)]( - Y, Y.stride(0), - X, X.stride(0), - W, W.stride(0), - r, r.stride(0), - n_cols, eps, - BLOCK_SIZE = triton.next_power_of_2(n_cols), - num_warps = 16, - ) + # if not gemma: + _rms_layernorm_forward[(n_rows,)]( + Y, Y.stride(0), + X, X.stride(0), + W, W.stride(0), + r, r.stride(0), + n_cols = int(n_cols), + eps = float(eps), + BLOCK_SIZE = triton.next_power_of_2(n_cols), + num_warps = 16, + ) + # else: + # _gemma_rms_layernorm_forward[(n_rows,)]( + # Y, Y.stride(0), + # X, X.stride(0), + # W, W.stride(0), + # r, r.stride(0), + # n_cols, eps, + # BLOCK_SIZE = triton.next_power_of_2(n_cols), + # num_warps = 16, + # ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.GEMMA = gemma ctx.save_for_backward(X, W, r) - return Y.view(*shape) + return Y#.view(*shape) pass @staticmethod From bd4ac7b21840d0c83464589f179eb16bce793c83 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:47:38 -0800 Subject: [PATCH 146/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index c6aecc6d2f..287c8219b5 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -137,8 +137,8 @@ class Fast_RMS_Layernorm(torch.autograd.Function): @staticmethod def forward(ctx, X, W, eps : float, gemma : bool = False): shape = X.shape - dim : int = shape[-1] - # X = X.view(-1, dim) + # dim : int = shape[-1] + X = X.view(shape[0] * shape[1], shape[2]) n_rows : int n_cols : int n_rows, n_cols = X.shape From 6f38731dfbf36093561c220f218acc3aec772a64 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 4 Nov 2024 23:54:02 -0800 Subject: [PATCH 147/209] Update rms_layernorm.py --- unsloth/kernels/rms_layernorm.py | 55 ++++++++++++-------------------- 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 287c8219b5..0846a09de6 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -20,17 +20,12 @@ @triton.jit def _rms_layernorm_forward( - Y, - Y_row_stride, - X, - X_row_stride, - W, - W_row_stride, - r, - r_row_stride, - n_cols, - eps, - BLOCK_SIZE : tl.constexpr, + Y, Y_row_stride, + X, X_row_stride, + W, W_row_stride, + r, r_row_stride, + n_cols, eps, + BLOCK_SIZE : tl.constexpr ): """ Fast RMS Layernorm kernel @@ -135,10 +130,10 @@ def _gemma_rms_layernorm_forward( class Fast_RMS_Layernorm(torch.autograd.Function): @staticmethod - def forward(ctx, X, W, eps : float, gemma : bool = False): + def forward(ctx, X, W, eps :float, gemma : bool = False): shape = X.shape - # dim : int = shape[-1] - X = X.view(shape[0] * shape[1], shape[2]) + dim : int = shape[-1] + X = X.view(-1, dim) n_rows : int n_cols : int n_rows, n_cols = X.shape @@ -146,36 +141,25 @@ def forward(ctx, X, W, eps : float, gemma : bool = False): num_warps : int BLOCK_SIZE, num_warps = calculate_settings(n_cols) - Y : torch.Tensor = torch.empty(X.shape, dtype = X.dtype, device = "cuda:0") - r : torch.Tensor = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") + Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0") + r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0") - # if not gemma: - _rms_layernorm_forward[(n_rows,)]( + fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward + fx[(n_rows,)]( Y, Y.stride(0), X, X.stride(0), W, W.stride(0), r, r.stride(0), - n_cols = int(n_cols), - eps = float(eps), - BLOCK_SIZE = triton.next_power_of_2(n_cols), - num_warps = 16, + n_cols, eps, + BLOCK_SIZE = BLOCK_SIZE, + num_warps = num_warps, ) - # else: - # _gemma_rms_layernorm_forward[(n_rows,)]( - # Y, Y.stride(0), - # X, X.stride(0), - # W, W.stride(0), - # r, r.stride(0), - # n_cols, eps, - # BLOCK_SIZE = triton.next_power_of_2(n_cols), - # num_warps = 16, - # ) ctx.eps = eps ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps ctx.GEMMA = gemma ctx.save_for_backward(X, W, r) - return Y#.view(*shape) + return Y.view(*shape) pass @staticmethod @@ -206,9 +190,10 @@ def backward(ctx, dY): pass -def fast_rms_layernorm(layernorm, X, gemma : bool = False): +@torch.compiler.disable +def fast_rms_layernorm(layernorm, X, gemma = False): W = layernorm.weight - eps : float = layernorm.variance_epsilon if \ + eps = layernorm.variance_epsilon if \ hasattr(layernorm, "variance_epsilon") \ else layernorm.eps out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) From 2df35d43e517916060266723254638756a8fc111 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 00:05:56 -0800 Subject: [PATCH 148/209] typing --- unsloth/kernels/rms_layernorm.py | 17 +++++++++-------- unsloth/kernels/rope_embedding.py | 18 ++++++++++++++++-- 2 files changed, 25 insertions(+), 10 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 0846a09de6..4b22f8c3e5 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -60,7 +60,7 @@ def _rms_layernorm_backward( X, X_row_stride, W, W_row_stride, r, r_row_stride, - dW, dW_row_stride, + # dW, dW_row_stride, n_cols, eps, GEMMA : tl.constexpr, BLOCK_SIZE : tl.constexpr, @@ -130,7 +130,7 @@ def _gemma_rms_layernorm_forward( class Fast_RMS_Layernorm(torch.autograd.Function): @staticmethod - def forward(ctx, X, W, eps :float, gemma : bool = False): + def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False): shape = X.shape dim : int = shape[-1] X = X.view(-1, dim) @@ -163,7 +163,7 @@ def forward(ctx, X, W, eps :float, gemma : bool = False): pass @staticmethod - def backward(ctx, dY): + def backward(ctx, dY : torch.Tensor): shape = dY.shape dim : int = shape[-1] dY = dY.view(-1, dim) @@ -171,14 +171,14 @@ def backward(ctx, dY): n_rows : int n_cols : int n_rows, n_cols = dY.shape - dW = X + # dW = X _rms_layernorm_backward[(n_rows,)]( dY, dY.stride(0), X, X .stride(0), W, W .stride(0), r, r .stride(0), - dW, dW.stride(0), + # dW, dW.stride(0), n_cols, ctx.eps, GEMMA = ctx.GEMMA, BLOCK_SIZE = ctx.BLOCK_SIZE, @@ -190,10 +190,11 @@ def backward(ctx, dY): pass +# [TODO] Unsure why RMS Layernorm is not torch.compiling properly @torch.compiler.disable -def fast_rms_layernorm(layernorm, X, gemma = False): - W = layernorm.weight - eps = layernorm.variance_epsilon if \ +def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False): + W : torch.Tensor = layernorm.weight + eps : float = layernorm.variance_epsilon if \ hasattr(layernorm, "variance_epsilon") \ else layernorm.eps out = Fast_RMS_Layernorm.apply(X, W, eps, gemma) diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 2934ac41c9..44a7cda12f 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -18,7 +18,7 @@ from .utils import calculate_settings ROPE_GROUP_SIZE = 4 -@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],}) +@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),}) @triton.jit def _rope_embedding( Q, Q_row_stride, @@ -75,8 +75,14 @@ class Fast_RoPE_Embedding(torch.autograd.Function): @staticmethod def forward(ctx, Q, cos, sin): cos, sin = cos.squeeze(), sin.squeeze() + batch : int + seq_len : int + n_heads : int + head_dim : int batch, seq_len, n_heads, head_dim = Q.shape Q = Q.view(batch*seq_len, n_heads*head_dim) + n_rows : int + n_cols : int n_rows, n_cols = Q.shape assert(seq_len <= cos.shape[0]) @@ -85,8 +91,10 @@ def forward(ctx, Q, cos, sin): BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2) # group_size = 4 # 4 or 8, too large group_size can hurt performance. + div : int + mod : int div, mod = divmod(n_heads, ROPE_GROUP_SIZE) - n_groups = div + (mod != 0) + n_groups : int = div + (mod != 0) _rope_embedding[(n_rows, n_groups, )]( Q, Q.stride(0), @@ -108,9 +116,15 @@ def forward(ctx, Q, cos, sin): @staticmethod def backward(ctx, dY): + batch : int + seq_len : int + n_heads : int + head_dim : int batch, seq_len, n_heads, head_dim = dY.shape dY = dY.reshape(batch*seq_len, n_heads*head_dim) # Must be reshape not view + n_rows : int + n_cols : int n_rows, n_cols = dY.shape cos = ctx.cos From 74d89d11552d24df9ee76c4cec94cd895c363c0c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 00:08:16 -0800 Subject: [PATCH 149/209] Update rope_embedding.py --- unsloth/kernels/rope_embedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 44a7cda12f..a8173d2453 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -25,9 +25,9 @@ def _rope_embedding( cos, cos_row_stride, sin, sin_row_stride, seqlen, - head_dim : tl.constexpr, - n_heads : tl.constexpr, - BACKWARD_PASS : tl.constexpr, + head_dim, + n_heads, + BACKWARD_PASS, BLOCK_SIZE : tl.constexpr, ): """ From 98927ee333aed07f9566dd05aa8ed1c357d993ae Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 00:09:11 -0800 Subject: [PATCH 150/209] types --- unsloth/kernels/rms_layernorm.py | 2 +- unsloth/kernels/rope_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 4b22f8c3e5..10b435dd57 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -191,7 +191,7 @@ def backward(ctx, dY : torch.Tensor): # [TODO] Unsure why RMS Layernorm is not torch.compiling properly -@torch.compiler.disable +# @torch.compiler.disable def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False): W : torch.Tensor = layernorm.weight eps : float = layernorm.variance_epsilon if \ diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index a8173d2453..246055dc7b 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -16,7 +16,7 @@ import triton.language as tl import torch from .utils import calculate_settings -ROPE_GROUP_SIZE = 4 +ROPE_GROUP_SIZE : int = 4 @triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),}) @triton.jit From f3e2bd6e5ebb83c21d3bd07c3268d38cf7fb4418 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 00:11:32 -0800 Subject: [PATCH 151/209] Disable compiling --- unsloth/kernels/rms_layernorm.py | 2 +- unsloth/kernels/rope_embedding.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/rms_layernorm.py b/unsloth/kernels/rms_layernorm.py index 10b435dd57..4b22f8c3e5 100644 --- a/unsloth/kernels/rms_layernorm.py +++ b/unsloth/kernels/rms_layernorm.py @@ -191,7 +191,7 @@ def backward(ctx, dY : torch.Tensor): # [TODO] Unsure why RMS Layernorm is not torch.compiling properly -# @torch.compiler.disable +@torch.compiler.disable def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False): W : torch.Tensor = layernorm.weight eps : float = layernorm.variance_epsilon if \ diff --git a/unsloth/kernels/rope_embedding.py b/unsloth/kernels/rope_embedding.py index 246055dc7b..7fe15d0e3b 100644 --- a/unsloth/kernels/rope_embedding.py +++ b/unsloth/kernels/rope_embedding.py @@ -25,9 +25,9 @@ def _rope_embedding( cos, cos_row_stride, sin, sin_row_stride, seqlen, - head_dim, - n_heads, - BACKWARD_PASS, + head_dim : tl.constexpr, + n_heads : tl.constexpr, + BACKWARD_PASS : tl.constexpr, BLOCK_SIZE : tl.constexpr, ): """ @@ -144,7 +144,8 @@ def backward(ctx, dY): pass pass - +# [TODO] Unsure why RoPE Embedding is not torch.compiling properly +@torch.compiler.disable def fast_rope_embedding(Q, K, cos, sin): Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2) K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2) From c30bd2a2caaff79ca71f9c2df004fdef810f3a91 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 00:15:56 -0800 Subject: [PATCH 152/209] Update _utils.py --- unsloth/models/_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cf34767d7b..cdb28234c7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -321,7 +321,8 @@ def _is_openai_available(): return False ) pass import xformers.ops.fmha as xformers -xformers_attention = xformers.memory_efficient_attention +# [TODO] Unsure why Xformers is also breaking as well +xformers_attention = torch.compiler.disable(xformers.memory_efficient_attention) # Check TRL version from trl import __version__ as trl_version From 813cbdd220e8b7a994fd33373122f2189556e2bf Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:30:34 -0800 Subject: [PATCH 153/209] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cdb28234c7..22aa4d3d80 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -395,13 +395,13 @@ def is_big_gpu(index): "max_autotune" : True, "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, - "triton.cudagraphs" : False, + "triton.cudagraphs" : True, } import accelerate def torch_compile_kwargs(*args, **kwargs): print("Unsloth: Enabled auto compiling") - return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,} + return {"dynamic" : True, "fullgraph" : True, "options" : torch_compile_options,} pass accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs From 34ce5d1dd9c44e52f098d0c2536b76a35ff7771f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:36:11 -0800 Subject: [PATCH 154/209] Forward hook --- unsloth/models/_utils.py | 14 +++++++------- unsloth/models/llama.py | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 22aa4d3d80..72994fa623 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -467,13 +467,13 @@ def prepare_model_for_kbit_training( pass # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. - if use_reentrant: - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + # if use_reentrant: + # if hasattr(model, "enable_input_require_grads"): + # model.enable_input_require_grads() + # else: + # def make_inputs_require_grad(module, input, output): + # output.requires_grad_(True) + # model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c8bf5f8894..61857431b9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -606,6 +606,7 @@ def LlamaModel_fast_forward( # Embed positions if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds.requires_grad_(True) # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) From f84cf4be32c4b13120f1e81e85ebbf44c0c19703 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:41:04 -0800 Subject: [PATCH 155/209] Update _utils.py --- unsloth/models/_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 72994fa623..cdb28234c7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -395,13 +395,13 @@ def is_big_gpu(index): "max_autotune" : True, "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, - "triton.cudagraphs" : True, + "triton.cudagraphs" : False, } import accelerate def torch_compile_kwargs(*args, **kwargs): print("Unsloth: Enabled auto compiling") - return {"dynamic" : True, "fullgraph" : True, "options" : torch_compile_options,} + return {"dynamic" : True, "fullgraph" : False, "options" : torch_compile_options,} pass accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs @@ -467,13 +467,13 @@ def prepare_model_for_kbit_training( pass # If use_reentrant = True which is the Pytorch default, we just make the input requires_grad. - # if use_reentrant: - # if hasattr(model, "enable_input_require_grads"): - # model.enable_input_require_grads() - # else: - # def make_inputs_require_grad(module, input, output): - # output.requires_grad_(True) - # model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) return model pass From 745814c4234df38087780ad0f6598321d903b1ba Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:42:44 -0800 Subject: [PATCH 156/209] Update llama.py --- unsloth/models/llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 61857431b9..d379c56a84 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -390,7 +390,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if False:#(not HAS_FLASH_ATTENTION and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -430,7 +430,7 @@ def LlamaAttention_fast_forward( Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + A = scaled_dot_product_attention(Q, K, V, is_causal = True) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass @@ -606,7 +606,6 @@ def LlamaModel_fast_forward( # Embed positions if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds.requires_grad_(True) # inputs_embeds = inputs_embeds.to(self.config.torch_dtype) torch_dtype = __DTYPE_MAP.get(self.config.torch_dtype, None) From ab9f8e17513dc266fcdb1756b0d9b4fe0ae12ce8 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:52:27 -0800 Subject: [PATCH 157/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cdb28234c7..5a7d0b005f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -376,7 +376,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = False #"UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From daa79099940bcea1964c4973facb1185c19c1430 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 01:55:17 -0800 Subject: [PATCH 158/209] Update llama.py --- unsloth/models/llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d379c56a84..3b8582581f 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -390,7 +390,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if False:#(not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -430,7 +430,7 @@ def LlamaAttention_fast_forward( Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, is_causal = True) + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass @@ -527,6 +527,7 @@ def LlamaDecoderLayer_fast_forward( } # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 +@torch._disable_dynamo def LlamaModel_fast_forward( self, input_ids: torch.LongTensor, From 536a1a6e522d6efda1b41896d787fb624bce2688 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 02:06:27 -0800 Subject: [PATCH 159/209] Update llama.py --- unsloth/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3b8582581f..3c4d8f3b38 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -527,7 +527,6 @@ def LlamaDecoderLayer_fast_forward( } # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 -@torch._disable_dynamo def LlamaModel_fast_forward( self, input_ids: torch.LongTensor, @@ -1029,6 +1028,7 @@ def _CausalLM_fast_forward( pass +@torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, input_ids=None, From 648ca59185637a540dea2c233f2d067bad83eb3c Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 12:05:17 -0800 Subject: [PATCH 160/209] Update _utils.py --- unsloth/models/_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5a7d0b005f..94cf1b74e0 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -321,8 +321,7 @@ def _is_openai_available(): return False ) pass import xformers.ops.fmha as xformers -# [TODO] Unsure why Xformers is also breaking as well -xformers_attention = torch.compiler.disable(xformers.memory_efficient_attention) +xformers_attention = xformers.memory_efficient_attention # Check TRL version from trl import __version__ as trl_version @@ -375,8 +374,8 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings -UNSLOTH_COMPILE_DEBUG = True #"UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = True #"UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os.environ +UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) From 486d0d6b338d472deb233a3fc0cfb82b5edaeadd Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 13:28:25 -0800 Subject: [PATCH 161/209] Update pyproject.toml --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8922cc7c8d..e0c5d93562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ exclude = ["images*"] [project.optional-dependencies] huggingface = [ - "unsloth_zoo", + "unsloth_zoo>=2024.11.1", "packaging", "tyro", "transformers>=4.46.1", @@ -244,7 +244,7 @@ colab-ampere-torch220 = [ "flash-attn>=2.6.3", ] colab-new = [ - "unsloth_zoo", + "unsloth_zoo>=2024.11.1", "packaging", "tyro", "transformers>=4.46.1", From eb4da9d8cd1aeb524e4b83b168c8977834ba1ad5 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 13:35:43 -0800 Subject: [PATCH 162/209] Update _utils.py --- unsloth/models/_utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 94cf1b74e0..b5c17434c7 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -88,8 +88,9 @@ # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub") -warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers") warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") @@ -374,8 +375,7 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings -UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ + # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) @@ -387,7 +387,14 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile(debug = UNSLOTH_COMPILE_DEBUG, O3 = UNSLOTH_COMPILE_MAXIMUM) + + +UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" +UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" +patch_torch_compile( + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, +) torch_compile_options = { "epilogue_fusion" : True, From da397f4f184039a4de1a2af902e213c73653f33b Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 13:52:12 -0800 Subject: [PATCH 163/209] Update llama.py --- unsloth/models/llama.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c4d8f3b38..c4488127d9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -390,7 +390,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if (not HAS_FLASH_ATTENTION and attention_mask is None): + if False:#(not HAS_FLASH_ATTENTION and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -427,10 +427,10 @@ def LlamaAttention_fast_forward( pass # Must be contiguous or else results are False! # https://github.com/pytorch/pytorch/issues/112577 - Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) + A = scaled_dot_product_attention(Q, K, V, is_causal = True) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass @@ -1028,7 +1028,6 @@ def _CausalLM_fast_forward( pass -@torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, input_ids=None, From 70b65cf7a25dd6e73df704ffe040711826bbb2d1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 14:40:57 -0800 Subject: [PATCH 164/209] CE Loss --- unsloth/kernels/cross_entropy_loss.py | 10 ++++++---- unsloth/models/_utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f0193c74d8..13d90baafe 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -181,10 +181,10 @@ def _chunked_cross_entropy_forward( pass -@triton.heuristics({ - "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), - "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -}) +# @triton.heuristics({ +# "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), +# "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), +# }) @triton.jit def _cross_entropy_backward( logits_ptr , @@ -345,6 +345,8 @@ def backward(ctx, dlosses): n_rows, vocab_size = logits.shape BLOCK_SIZE : int = 4096 + div : int + mod : int div, mod = divmod(vocab_size, BLOCK_SIZE) n_blocks : int = div + (mod != 0) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b5c17434c7..bb004d588d 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -393,7 +393,7 @@ def is_big_gpu(index): UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" patch_torch_compile( debug = UNSLOTH_COMPILE_DEBUG, - O3 = UNSLOTH_COMPILE_MAXIMUM, + O3 = UNSLOTH_COMPILE_MAXIMUM, ) torch_compile_options = { From aeec57e110994142ac3d64bc4a74e077c2055565 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 14:43:49 -0800 Subject: [PATCH 165/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 13d90baafe..e5136f8412 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -181,10 +181,10 @@ def _chunked_cross_entropy_forward( pass -# @triton.heuristics({ -# "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), -# "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), -# }) +@triton.heuristics({ + "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]), + "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]), +}) @triton.jit def _cross_entropy_backward( logits_ptr , @@ -337,6 +337,8 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : return losses pass + + @torch.compiler.disable @staticmethod def backward(ctx, dlosses): logits, logsumexp, labels = ctx.saved_tensors From fb393fc9baa84794b95becad0a030e8e16b3d35a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 14:54:07 -0800 Subject: [PATCH 166/209] Update _utils.py --- unsloth/models/_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index bb004d588d..95ea381f00 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -389,11 +389,13 @@ def is_big_gpu(index): torch._inductor.utils.is_big_gpu = is_big_gpu -UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" -UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" +UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" +UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" +UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" patch_torch_compile( - debug = UNSLOTH_COMPILE_DEBUG, - O3 = UNSLOTH_COMPILE_MAXIMUM, + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, ) torch_compile_options = { From cab1e722d68d3255c4a46c5790411047e34e2993 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 20:29:30 -0800 Subject: [PATCH 167/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 28 +++++++++++++-------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index e5136f8412..bb8f002ac8 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -104,18 +104,17 @@ def _cross_entropy_forward( }) @triton.jit def _chunked_cross_entropy_forward( - logits_ptr , - logits_row_stride , - loss_ptr , - logsumexp_ptr , - labels_ptr , - VOCAB_SIZE , - N_CHUNKS , - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING , - SOFTCAP , - DO_LOGIT_SCALING , - LOGIT_SCALE , + logits_ptr, logits_row_stride, + loss_ptr, + logsumexp_ptr, + labels_ptr, + VOCAB_SIZE : tl.constexpr, + N_CHUNKS : tl.constexpr, + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING : tl.constexpr, + SOFTCAP : tl.constexpr, + DO_LOGIT_SCALING: tl.constexpr, + LOGIT_SCALE : tl.constexpr, ): """ 256K vocab divided in 4 chunks @@ -143,7 +142,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * logits_row_stride.to(tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -157,7 +156,7 @@ def _chunked_cross_entropy_forward( # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) logits = logits.to(tl.float32) c = tl.max(logits, 0) @@ -338,7 +337,6 @@ def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : pass - @torch.compiler.disable @staticmethod def backward(ctx, dlosses): logits, logsumexp, labels = ctx.saved_tensors From 51fea97dec9e94350389d3b71a5932c0ad6c1564 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 20:58:15 -0800 Subject: [PATCH 168/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 28 +++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index bb8f002ac8..64825bac5c 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -104,17 +104,18 @@ def _cross_entropy_forward( }) @triton.jit def _chunked_cross_entropy_forward( - logits_ptr, logits_row_stride, - loss_ptr, - logsumexp_ptr, - labels_ptr, - VOCAB_SIZE : tl.constexpr, - N_CHUNKS : tl.constexpr, - BLOCK_SIZE : tl.constexpr, - DO_SOFTCAPPING : tl.constexpr, - SOFTCAP : tl.constexpr, - DO_LOGIT_SCALING: tl.constexpr, - LOGIT_SCALE : tl.constexpr, + logits_ptr , + logits_row_stride , + loss_ptr , + logsumexp_ptr , + labels_ptr , + VOCAB_SIZE , + N_CHUNKS , + BLOCK_SIZE : tl.constexpr, + DO_SOFTCAPPING , + SOFTCAP , + DO_LOGIT_SCALING , + LOGIT_SCALE , ): """ 256K vocab divided in 4 chunks @@ -142,7 +143,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * logits_row_stride.to(tl.int64) + logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -151,14 +152,13 @@ def _chunked_cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) - logits = logits.to(tl.float32) c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) From 58e541bd910e28ee617bc6dd044dcbfee400e8be Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 21:01:33 -0800 Subject: [PATCH 169/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 64825bac5c..0c07035097 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -73,14 +73,13 @@ def _cross_entropy_forward( mask = col_offsets < VOCAB_SIZE label_idx = tl.load(labels_ptr).to(tl.int32) - logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits # Do logit softcapping for Gemma 2: t * tanh(1/t * x) - if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits.to(tl.float32) / SOFTCAP).to(logits.dtype) - - logits = logits.to(tl.float32) + if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP) + c = tl.max(logits, 0) logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) @@ -228,7 +227,7 @@ def _cross_entropy_backward( else: dloss = 0.0 - x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")) + x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32) # Do logit scaling for Cohere if DO_LOGIT_SCALING: @@ -240,12 +239,12 @@ def _cross_entropy_backward( partial = x if DO_SOFTCAPPING: # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x) - partial = triton_tanh(x.to(tl.float32) / SOFTCAP).to(x.dtype) + partial = triton_tanh(x / SOFTCAP) x = SOFTCAP * partial pass logsumexp = tl.load(logsumexp_ptr + row_idx) - y = tl.exp(x.to(tl.float32) - logsumexp) + y = tl.exp(x - logsumexp) y = tl.where( col_offsets == label_idx, y - 1.0, # exp(x - logsumexp) - 1 From ef2c56f3da1508bb2753f164d35b853e67a81c4f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 21:07:02 -0800 Subject: [PATCH 170/209] Update llama.py --- unsloth/models/llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index c4488127d9..3c4d8f3b38 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -390,7 +390,7 @@ def LlamaAttention_fast_forward( past_key_value = (K, V) if use_cache else None # Attention module - if False:#(not HAS_FLASH_ATTENTION and attention_mask is None): + if (not HAS_FLASH_ATTENTION and attention_mask is None): # Xformers memory efficient attention # Also has Flash Attention v2 dispatching Q = Q.transpose(1, 2) @@ -427,10 +427,10 @@ def LlamaAttention_fast_forward( pass # Must be contiguous or else results are False! # https://github.com/pytorch/pytorch/issues/112577 - # Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() # Needs (batch_size, n_heads, seq_len, head_dim) # is_casual and attention_mask must not be both set! - A = scaled_dot_product_attention(Q, K, V, is_causal = True) + A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False) # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass @@ -1028,6 +1028,7 @@ def _CausalLM_fast_forward( pass +@torch._disable_dynamo def PeftModelForCausalLM_fast_forward( self, input_ids=None, From 13d7412bbfa82b4c2058da1b5a5f452dc868aa60 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 21:24:56 -0800 Subject: [PATCH 171/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 94cf1b74e0..e1a90993ed 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.11.1" +__version__ = "2024.11.3" __all__ = [ "prepare_model_for_kbit_training", From 5a7eaf8a60d6bf187304f10eb0ebdc5f5d2814e2 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 21:44:38 -0800 Subject: [PATCH 172/209] Update _utils.py --- unsloth/models/_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e1a90993ed..b105ea7494 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -88,8 +88,9 @@ # Disable some warnings which can get annoying warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub") -warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub") +warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl") +warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl") warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers") warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess") warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers") @@ -374,8 +375,9 @@ def _is_openai_available(): return False # ============================================= # Torch compile settings -UNSLOTH_COMPILE_DEBUG = "UNSLOTH_COMPILE_DEBUG" in os.environ -UNSLOTH_COMPILE_MAXIMUM = "UNSLOTH_COMPILE_MAXIMUM" in os.environ +UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" +UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" +UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) @@ -387,7 +389,11 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile(debug = UNSLOTH_COMPILE_DEBUG, O3 = UNSLOTH_COMPILE_MAXIMUM) +patch_torch_compile( + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, +) torch_compile_options = { "epilogue_fusion" : True, From d2186ed3c83d7aab612cad55a3c23201a26e16f1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Tue, 5 Nov 2024 22:48:36 -0800 Subject: [PATCH 173/209] Update _utils.py --- unsloth/models/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index b105ea7494..4a9b8847ad 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -396,8 +396,8 @@ def is_big_gpu(index): ) torch_compile_options = { - "epilogue_fusion" : True, - "max_autotune" : True, + "epilogue_fusion" : False, + "max_autotune" : False, "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, "triton.cudagraphs" : False, From 6434447f25e63cfdc5afa5bc7dbfe742aed1c2cc Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 00:18:04 -0800 Subject: [PATCH 174/209] Update _utils.py --- unsloth/models/_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4a9b8847ad..cdc2cd45fd 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -377,7 +377,7 @@ def _is_openai_available(): return False # Torch compile settings UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1" UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1" -UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "0") == "1" +UNSLOTH_COMPILE_IGNORE_ERRORS = os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1" # Just remove max_autotune_gemm warning import functools @functools.lru_cache(None) @@ -396,8 +396,8 @@ def is_big_gpu(index): ) torch_compile_options = { - "epilogue_fusion" : False, - "max_autotune" : False, + "epilogue_fusion" : True, + "max_autotune" : True, "shape_padding" : True, "trace.enabled" : UNSLOTH_COMPILE_DEBUG, "triton.cudagraphs" : False, From 67611e624670cb1163283c889734a996485fc1e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 02:10:37 -0800 Subject: [PATCH 175/209] Update _utils.py --- unsloth/models/_utils.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index cdc2cd45fd..903093e60f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -69,7 +69,6 @@ patch_compiling_bitsandbytes, patch_layernorm, patch_torch_compile, - patch_regional_compilation, patch_model_and_tokenizer, ) from unsloth_zoo.gradient_checkpointing import ( @@ -414,6 +413,26 @@ def torch_compile_kwargs(*args, **kwargs): accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs del accelerate +def patch_regional_compilation(): + # Regional torch 2.5 Recompilation - weirdly very slow?? + if torch.nn.ModuleList.__name__ == "UnslothModuleList": return + # Only works for torch 2.5 + if Version(torch.__version__) < Version("2.5.0"): return + + old_module_list = torch.nn.ModuleList + os.environ["UNSLOTH_PATCHED"] = "1" + + def UnslothModuleList(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and type(args[0]) is list: + args = [old_module_list([torch.compile(x, dynamic = True, options = torch_compile_options, fullgraph = False) for x in args[0]])] + return old_module_list(*args, **kwargs) + pass + UnslothModuleList.__doc__ = old_module_list.__doc__ + + torch.nn.ModuleList = UnslothModuleList + return +pass + # ============================================= def prepare_model_for_kbit_training( From f24aef5cbc50467493d906e61ce95d3159ee957f Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Thu, 7 Nov 2024 00:16:02 +0400 Subject: [PATCH 176/209] Fix: cast logits to float32 in cross_entropy_forward to prevent errors (#1254) * Fix: cast logits to float32 in cross_entropy_forward to prevent errors * Update cross_entropy_loss.py --------- Co-authored-by: Daniel Han --- unsloth/kernels/cross_entropy_loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index 0c07035097..cc3dbb1d87 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -84,12 +84,12 @@ def _cross_entropy_forward( logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0)) if label_idx != -100: - x = tl.load(logits_ptr + label_idx) + x = tl.load(logits_ptr + label_idx).to(tl.float32) # Go logit scaling for Cohere: t * x if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) - loss = logsumexp - x.to(tl.float32) + loss = logsumexp - x else: loss = 0.0 tl.store(logsumexp_ptr, logsumexp) @@ -170,7 +170,7 @@ def _chunked_cross_entropy_forward( if DO_LOGIT_SCALING: x = LOGIT_SCALE * x # Do logit softcapping for Gemma 2: t * tanh(1/t * x) if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP) - loss = -1.0 * x.to(tl.float32) + loss = -1.0 * x else: loss = 0.0 tl.store(loss_ptr, loss) From 3d906e637847b689544dc804305ffc8365e17af3 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 7 Nov 2024 01:52:08 +0530 Subject: [PATCH 177/209] Throw error when inferencing longer than max_popsition_embeddings (#1236) * Throw error when inferencing longer than max_popsition_embeddings without rope scaling * Update llama.py --------- Co-authored-by: Daniel Han --- unsloth/models/llama.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 3c4d8f3b38..7f07bea4c5 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1376,6 +1376,15 @@ def _wrap_fast_inference(generate, device_type, dtype, model): @torch.inference_mode def _fast_generate(*args, **kwargs): + if hasattr(model, "config") and hasattr(model.config, "max_position_embeddings"): + if "input_ids" in kwargs and kwargs["input_ids"] is not None and "max_new_tokens" in kwargs: + if kwargs["input_ids"].shape[-1] + kwargs["max_new_tokens"] > model.config.max_position_embeddings: + raise ValueError( + f'Unsloth: input length {kwargs["input_ids"].shape[-1]} + max_new_tokens {kwargs["max_new_tokens"]} exceeds the maximum sequence length of {model.config.max_position_embeddings}!\n'\ + 'You will need to do long context extension by increasing the `max_seq_length` in `FastLanguageModel.from_pretrained`.' + ) + pass + # Set a flag for generation! internal_model = model while hasattr(internal_model, "model"): From de1049bcc524e6c3eff5874e1c79aad603cd9f97 Mon Sep 17 00:00:00 2001 From: Edwin Fennell Date: Wed, 6 Nov 2024 20:23:09 +0000 Subject: [PATCH 178/209] CLI now handles user input strings for dtype correctly (#1235) Co-authored-by: root --- unsloth/models/loader.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index db7259b1d9..8dcdebab12 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -44,6 +44,25 @@ from .gemma2 import FastGemma2Model pass +def get_dtype_from_input( + dtype +): + '''Converts user-defined dtype input string to a usable dtype''' + TORCH_FLOAT16_SYNONYMS = {"torch.float16"} + TORCH_BFLOAT16_SYNONYMS = {"torch.bfloat16"} + TORCH_FLOAT32_SYNONYMS = {"torch.float32"} + if dtype in TORCH_FLOAT16_SYNONYMS: + return torch.float16 + if dtype in TORCH_BFLOAT16_SYNONYMS: + return torch.bfloat16 + if dtype in TORCH_FLOAT32_SYNONYMS: + return torch.float32 + if dtype != "None": + print(f"--------------------------------------------------\n"\ + f"User-specified dtype not recognised. Defaulting to dtype = None\n"\ + f"--------------------------------------------------") + return None + def __get_model_name( model_name, @@ -332,7 +351,7 @@ def from_pretrained( model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, - dtype = dtype, + dtype = get_dtype_from_input(dtype), load_in_4bit = load_in_4bit, token = token, device_map = device_map, From be72975d8aba9c7cbbdc906d87a41f5ac58d8cea Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 13:09:29 -0800 Subject: [PATCH 179/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 08426b69e0..7deb9a96d3 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -20,7 +20,7 @@ "epilogue_fusion" : True, "max_autotune" : True, "shape_padding" : True, - "trace.enabled" : False, # Output Triton kernel outputs! + "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1", "triton.cudagraphs" : False, } From 05170cd81a522d3a6f9523c60a2a6e3ea31da5cb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 14:05:52 -0800 Subject: [PATCH 180/209] Update _utils.py --- unsloth/models/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 903093e60f..d645235f02 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -388,11 +388,11 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile( - debug = UNSLOTH_COMPILE_DEBUG, - O3 = UNSLOTH_COMPILE_MAXIMUM, - ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, -) +# patch_torch_compile( +# debug = UNSLOTH_COMPILE_DEBUG, +# O3 = UNSLOTH_COMPILE_MAXIMUM, +# ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, +# ) torch_compile_options = { "epilogue_fusion" : True, From 7e0877d383a1937039902b65d48fdc16ac6c8c0a Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 14:49:37 -0800 Subject: [PATCH 181/209] Update _utils.py --- unsloth/models/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index d645235f02..903093e60f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -388,11 +388,11 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -# patch_torch_compile( -# debug = UNSLOTH_COMPILE_DEBUG, -# O3 = UNSLOTH_COMPILE_MAXIMUM, -# ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, -# ) +patch_torch_compile( + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, +) torch_compile_options = { "epilogue_fusion" : True, From 6b5c5993947fac6fcd62041a82587a87a41e2176 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 14:51:52 -0800 Subject: [PATCH 182/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 7deb9a96d3..887ffca1b7 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -15,6 +15,7 @@ import torch from functools import lru_cache from transformers.models.llama.modeling_llama import logger +import os torch_compile_options = { "epilogue_fusion" : True, From 1ba9f2ed87133f9fb3c7c5ec69e3d0f9d78fd949 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 14:54:53 -0800 Subject: [PATCH 183/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 887ffca1b7..dfd48504d5 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -40,6 +40,14 @@ if not HAS_FLEX_ATTENTION: + # Below fails on compiled_autograd, so disable it + try: + old_compiled_autograd = torch._dynamo.config.compiled_autograd + torch._dynamo.config.compiled_autograd = False + except: + old_compiled_autograd = False + pass + # Logit softcapping @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): @@ -74,6 +82,13 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): return A pass + # Return compiled_autograd back + try: + torch._dynamo.config.compiled_autograd = old_compiled_autograd + except: + pass + pass + create_flex_attention_causal_mask = None create_flex_attention_sliding_window_mask = None else: From da61c4dde9247fcf4623c387fac4ee5612163367 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 15:05:10 -0800 Subject: [PATCH 184/209] Update loader.py --- unsloth/models/loader.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 8dcdebab12..cafb1282f7 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -44,24 +44,23 @@ from .gemma2 import FastGemma2Model pass -def get_dtype_from_input( - dtype -): - '''Converts user-defined dtype input string to a usable dtype''' - TORCH_FLOAT16_SYNONYMS = {"torch.float16"} - TORCH_BFLOAT16_SYNONYMS = {"torch.bfloat16"} - TORCH_FLOAT32_SYNONYMS = {"torch.float32"} - if dtype in TORCH_FLOAT16_SYNONYMS: + +def _get_dtype(dtype): + __DTYPE_MAP = { + "float32": torch.float32, + torch.float32: torch.float32, + "float16": torch.float16, + torch.float16: torch.float16, + "bfloat16": torch.bfloat16, + torch.bfloat16: torch.bfloat16, + } + if dtype in __DTYPE_MAP: + return __DTYPE_MAP[dtype] + else: + print(f"Unsloth: {dtype} is not recognized, so we'll default to torch.float16") return torch.float16 - if dtype in TORCH_BFLOAT16_SYNONYMS: - return torch.bfloat16 - if dtype in TORCH_FLOAT32_SYNONYMS: - return torch.float32 - if dtype != "None": - print(f"--------------------------------------------------\n"\ - f"User-specified dtype not recognised. Defaulting to dtype = None\n"\ - f"--------------------------------------------------") - return None + pass +pass def __get_model_name( @@ -351,7 +350,7 @@ def from_pretrained( model, tokenizer = dispatch_model.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, - dtype = get_dtype_from_input(dtype), + dtype = _get_dtype(dtype), load_in_4bit = load_in_4bit, token = token, device_map = device_map, From 3316ee2282869ae8f85b571b420d760d665f016f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 15:13:03 -0800 Subject: [PATCH 185/209] Update loader.py --- unsloth/models/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index cafb1282f7..4566302ed0 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -43,7 +43,7 @@ if SUPPORTS_GEMMA2: from .gemma2 import FastGemma2Model pass - +import torch def _get_dtype(dtype): __DTYPE_MAP = { From 501ca842970d17bfc7d77d007fa18d80d86381f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 15:39:32 -0800 Subject: [PATCH 186/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index dfd48504d5..678574928c 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -42,14 +42,15 @@ # Below fails on compiled_autograd, so disable it try: - old_compiled_autograd = torch._dynamo.config.compiled_autograd - torch._dynamo.config.compiled_autograd = False + disable_compiled_autograd = torch._dynamo.compiled_autograd.disable except: - old_compiled_autograd = False + disable_compiled_autograd = lambda *args, **kwargs: *args, **kwargs pass # Logit softcapping - @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) + @disable_compiled_autograd( + torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) + ) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim @@ -82,13 +83,6 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): return A pass - # Return compiled_autograd back - try: - torch._dynamo.config.compiled_autograd = old_compiled_autograd - except: - pass - pass - create_flex_attention_causal_mask = None create_flex_attention_sliding_window_mask = None else: From ce621b7af32b44ee2c6bbce5f9d302b6e0bc677e Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 15:56:26 -0800 Subject: [PATCH 187/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 678574928c..99e46a904c 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -44,7 +44,7 @@ try: disable_compiled_autograd = torch._dynamo.compiled_autograd.disable except: - disable_compiled_autograd = lambda *args, **kwargs: *args, **kwargs + disable_compiled_autograd = lambda f: f pass # Logit softcapping From 4b01ff1724f03bed74bfddf8963850e7c84d9d3f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 15:56:52 -0800 Subject: [PATCH 188/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 99e46a904c..1342bfafc2 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -42,15 +42,13 @@ # Below fails on compiled_autograd, so disable it try: - disable_compiled_autograd = torch._dynamo.compiled_autograd.disable + disable_compile = torch._dynamo.compiled_autograd.disable except: - disable_compiled_autograd = lambda f: f + disable_compile = lambda f: f pass # Logit softcapping - @disable_compiled_autograd( - torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) - ) + @disable_compile(torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim From ef5052a8bc97bbb6ca6e82fc8f7ede47ad37ca2d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 16:59:15 -0800 Subject: [PATCH 189/209] Update flex_attention.py --- unsloth/kernels/flex_attention.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 1342bfafc2..887ffca1b7 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -40,15 +40,8 @@ if not HAS_FLEX_ATTENTION: - # Below fails on compiled_autograd, so disable it - try: - disable_compile = torch._dynamo.compiled_autograd.disable - except: - disable_compile = lambda f: f - pass - # Logit softcapping - @disable_compile(torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)) + @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim From 52bca32ce9b2162aa60061d54a4f6c9a78a145f4 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 17:14:56 -0800 Subject: [PATCH 190/209] Update _utils.py --- unsloth/models/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 903093e60f..a6cd13d251 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.11.3" +__version__ = "2024.11.4" __all__ = [ "prepare_model_for_kbit_training", From 8b3e9c2ff0b3e5c6fb11d2913023a1d9fa069324 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 6 Nov 2024 21:07:50 -0800 Subject: [PATCH 191/209] Update cross_entropy_loss.py --- unsloth/kernels/cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index cc3dbb1d87..f82defd405 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -400,6 +400,6 @@ def fast_cross_entropy_loss( pass # Patch CE Losses in transformers -def patch_loss_functions(): - _patch_loss_functions(fast_cross_entropy_loss) +def patch_loss_functions(torch_compile = True): + _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile) pass From 3a1e7ef8299f3c96fa6e8de11fd0772af3cbc83f Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Thu, 7 Nov 2024 01:11:45 -0800 Subject: [PATCH 192/209] Update _utils.py --- unsloth/models/_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5fff96642d..cb9ae48a86 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -104,7 +104,7 @@ # Ignore logging messages class HideLoggingMessage(logging.Filter): def __init__(self, text): self.text = text - def filter(self, x): return not x.getMessage().startswith(self.text) + def filter(self, x): return not (self.text in x.getMessage()) pass # The speedups for torchdynamo mostly come wih GPU Ampere or higher and which is not detected here. @@ -112,6 +112,14 @@ def filter(self, x): return not x.getMessage().startswith(self.text) transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups")) del transformers_training_args_logger +# Using the default loss: `ForCausalLMLoss`. +try: + from transformers.modeling_utils import logger as transformers_modeling_utils_logger + transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss")) + del transformers_modeling_utils_logger +except: + pass + # ============================================= # ============================================= From f1ec165096f7d9f54ed988b546d60dec9b443dab Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 9 Nov 2024 17:01:33 -0800 Subject: [PATCH 193/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index c639dbf1a0..6d0ee548c1 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1001,13 +1001,14 @@ def patch_sft_trainer_tokenizer(): # Also DPO weirdly tokenizes non numeric columns? Delete them! check_text += \ "\n"\ - "column_names = set(self.train_dataset.column_names)\n"\ - "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ - " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ - " 'prompt_input_ids', 'prompt_attention_mask']\n"\ - "if all(x in column_names for x in check):\n"\ - " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ - "del check, column_names\n"\ + "if hasattr(self.train_dataset, 'column_names''):\n" + " column_names = set(self.train_dataset.column_names)\n"\ + " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ + " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ + " 'prompt_input_ids', 'prompt_attention_mask']\n"\ + " if all(x in column_names for x in check):\n"\ + " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n"\ + " del check, column_names\n"\ "\n" check_text = check_text.split("\n") From a4e970553164acca5fad28287c25e14af30891f9 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 9 Nov 2024 17:34:47 -0800 Subject: [PATCH 194/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 6d0ee548c1..7967676d80 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1001,7 +1001,7 @@ def patch_sft_trainer_tokenizer(): # Also DPO weirdly tokenizes non numeric columns? Delete them! check_text += \ "\n"\ - "if hasattr(self.train_dataset, 'column_names''):\n" + "if hasattr(self.train_dataset, 'column_names'):\n" " column_names = set(self.train_dataset.column_names)\n"\ " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ From 92c6a2784043df578e7ba7c4ff0e92b5866da09d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 9 Nov 2024 17:37:11 -0800 Subject: [PATCH 195/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 7967676d80..a9b635203e 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1015,6 +1015,8 @@ def patch_sft_trainer_tokenizer(): check_text = "\n".join(" "*where + x for x in check_text) function = function.replace(replacer, check_text + replacer) + print(function) + raise exec(function, globals()) exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals()) From 673f541788bdd858cf4ddb98edd3fa52fe4892c1 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Sat, 9 Nov 2024 17:40:32 -0800 Subject: [PATCH 196/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index a9b635203e..57829e621d 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -1001,7 +1001,7 @@ def patch_sft_trainer_tokenizer(): # Also DPO weirdly tokenizes non numeric columns? Delete them! check_text += \ "\n"\ - "if hasattr(self.train_dataset, 'column_names'):\n" + "if hasattr(self.train_dataset, 'column_names'):\n"\ " column_names = set(self.train_dataset.column_names)\n"\ " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n"\ " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n"\ @@ -1015,8 +1015,6 @@ def patch_sft_trainer_tokenizer(): check_text = "\n".join(" "*where + x for x in check_text) function = function.replace(replacer, check_text + replacer) - print(function) - raise exec(function, globals()) exec(f"trl.trainer.{path_to_trainer}.{function_name} = {function_name}", globals()) From 8fe9109431d9b13757429879338546ffba94ccf0 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 11 Nov 2024 00:04:02 -0800 Subject: [PATCH 197/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 57829e621d..ed95e07632 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -588,15 +588,21 @@ def load_correct_tokenizer( def _fix_chat_template(chat_template): endfor = "{% endfor %}" where = chat_template.find(endfor) - if where == -1: return chat_template + if where == -1: + endfor = "{%- endfor %}" + where = chat_template.find(endfor) + if where == -1: + return chat_template after_endfor = chat_template[where + len(endfor):] - if "{% if" not in after_endfor and "{% set " not in after_endfor and \ + dash = "-" if endfor.startswith("{%-") else "" + + if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \ after_endfor.startswith("{{") and after_endfor.endswith("}}") and \ after_endfor.count("{{") == 1 and after_endfor.count("}}") == 1: - after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}" + after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor chat_template = chat_template[:where + len(endfor)] + after_endfor pass @@ -643,10 +649,12 @@ def fix_chat_template(tokenizer): if no == yes: # SAME?! That's not good! We check for add_generation_prompt - if "{% if add_generation_prompt %}" not in chat_template: + if "{% if add_generation_prompt %}" not in chat_template and \ + "{%- if add_generation_prompt %}" not in chat_template: # Try fixing it by adding it new_chat_template = _fix_chat_template(chat_template) - if "{% if add_generation_prompt %}" not in new_chat_template: + if "{% if add_generation_prompt %}" not in new_chat_template and \ + "{%- if add_generation_prompt %}" not in new_chat_template: raise RuntimeError( f"Unsloth: The tokenizer `{tokenizer.name_or_path}`\n"\ "does not have a {% if add_generation_prompt %} for generation purposes.\n"\ From ad41479c5488cf69c79f1af32ea1d53bce0a08e3 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 11 Nov 2024 00:17:22 -0800 Subject: [PATCH 198/209] triton_cast --- unsloth/kernels/cross_entropy_loss.py | 8 ++++---- unsloth/kernels/utils.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py index f82defd405..d347cd1878 100644 --- a/unsloth/kernels/cross_entropy_loss.py +++ b/unsloth/kernels/cross_entropy_loss.py @@ -15,7 +15,7 @@ import triton import triton.language as tl import torch -from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh +from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast from transformers.models.llama.modeling_llama import logger from packaging.version import Version @@ -64,7 +64,7 @@ def _cross_entropy_forward( This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1. """ row_idx = tl.program_id(0) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx labels_ptr += row_idx @@ -142,7 +142,7 @@ def _chunked_cross_entropy_forward( """ row_idx = tl.program_id(0) chunk_idx = tl.program_id(1) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) loss_ptr += row_idx logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx labels_ptr += row_idx @@ -216,7 +216,7 @@ def _cross_entropy_backward( row_idx = tl.program_id(0) block_idx = tl.program_id(1) - logits_ptr += row_idx * tl.cast(logits_row_stride, tl.int64) + logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64) dloss_ptr += row_idx * dloss_row_stride col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = col_offsets < VOCAB_SIZE diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index b394d122fd..cef6ccb864 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -34,9 +34,15 @@ if Version(triton.__version__) >= Version("3.0.0"): from triton.language.extra import libdevice triton_tanh = libdevice.tanh + triton_cast = tl.cast else: import triton.language as tl triton_tanh = tl.math.tanh + # No casting in old Triton versions + @triton.jit + def triton_cast(x, dtype): + return x.to(dtype) + pass pass From fcf200997a9af81a9dd2799e584e622c5adee02d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 11 Nov 2024 00:37:37 -0800 Subject: [PATCH 199/209] Update utils.py --- unsloth/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index cef6ccb864..de543962ef 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -31,12 +31,12 @@ # tl.math.tanh now is libdevice.tanh from packaging.version import Version import triton +import triton.language as tl if Version(triton.__version__) >= Version("3.0.0"): from triton.language.extra import libdevice triton_tanh = libdevice.tanh triton_cast = tl.cast else: - import triton.language as tl triton_tanh = tl.math.tanh # No casting in old Triton versions @triton.jit From af9ba073da5b9c3dbcf6580d32faa9d3bd37a290 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Mon, 11 Nov 2024 18:46:05 -0800 Subject: [PATCH 200/209] Qwen 2.5 Coder --- unsloth/models/llama.py | 3 ++- unsloth/models/mapper.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7f07bea4c5..47a57024a2 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -2317,7 +2317,8 @@ def patch_peft_model( layer.self_attn.apply_qkv = apply_lora_qkv n_qkv += 1 else: - if model_type != "qwen2": + if model_type == "qwen2": n_qkv += 1 + else: logger.warning_once( "Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters\n"\ "are not enabled or a bias term (like in Qwen) is used." diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index 10e40ab7c6..d4f1278e1d 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -384,22 +384,54 @@ "unsloth/Qwen2.5-Math-72B-Instruct", "Qwen/Qwen2.5-Math-72B-Instruct", ), + "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B", + "Qwen/Qwen2.5-Coder-0.5B", + ), "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-1.5B", "Qwen/Qwen2.5-Coder-1.5B", ), + "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B", + "Qwen/Qwen2.5-Coder-3B", + ), "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-7B", "Qwen/Qwen2.5-Coder-7B", ), + "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B", + "Qwen/Qwen2.5-Coder-14B", + ), + "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B", + "Qwen/Qwen2.5-Coder-32B", + ), + "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-Instruct-0.5B", + "Qwen/Qwen2.5-Coder-Instruct-0.5B", + ), "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-Instruct-1.5B", "Qwen/Qwen2.5-Coder-Instruct-1.5B", ), + "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B-Instruct", + "Qwen/Qwen2.5-Coder-3B-Instruct", + ), "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( "unsloth/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct", ), + "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B-Instruct", + "Qwen/Qwen2.5-Coder-14B-Instruct", + ), + "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B-Instruct", + "Qwen/Qwen2.5-Coder-32B-Instruct", + ), "unsloth/Llama-3.2-1B-bnb-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", From 3fec577bb24d64bdf4ace32c91d10eaf0588a330 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:53:50 +0400 Subject: [PATCH 201/209] Fix/export mistral (#1281) * Enhance install_python_non_blocking to handle protobuf installation and process management * Revert "Enhance install_python_non_blocking to handle protobuf installation and process management" This reverts commit f09974b151df1a6ce4708bc4cf75e5eb6b024aed. * Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266 * Revert "Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266" This reverts commit 9fc130785dac65e9469306f71c666c155add53f1. * Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION to 'python' to address issue #1266 * Update __init__.py --------- Co-authored-by: Daniel Han --- unsloth/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 5102d8f466..d5651d5edb 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -31,6 +31,10 @@ # enabling it will require much more work, so we have to prioritize. Please understand! # We do have a beta version, which you can contact us about! # Thank you for your understanding and we appreciate it immensely! + +# Fixes https://github.com/unslothai/unsloth/issues/1266 +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + if "CUDA_VISIBLE_DEVICES" in os.environ: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" devices = os.environ["CUDA_VISIBLE_DEVICES"] From 03c624375f3b1e1e3772cd20a41d865ae1db3d82 Mon Sep 17 00:00:00 2001 From: Uday Girish Maradana Date: Wed, 13 Nov 2024 02:55:28 -0500 Subject: [PATCH 202/209] DOC Update - Update README.md with os.environ in example (#1269) * Update README.md with os.environ in example Added OS Environ in example to avoid device conflicts , for a user at least in jupyter notebook this allows to select GPU in a multi GPU setup. As currently the unsloth init checks all GPU's and takes the first in the order which can be a issue when some GPU's are in use and the list still shows them. So to manually avoid this, this os config is required. Small change but a bit time saver for those who straight away copies the tutorials * Update README.md --------- Co-authored-by: Daniel Han --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 0a3c83f3fa..4d68d996f0 100644 --- a/README.md +++ b/README.md @@ -299,6 +299,9 @@ DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)! ```python +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID + from unsloth import FastLanguageModel, PatchDPOTrainer from unsloth import is_bfloat16_supported PatchDPOTrainer() From 10565efe27beffe9bd73827bb892b15aba62eb68 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Wed, 13 Nov 2024 12:06:48 +0400 Subject: [PATCH 203/209] fix/get_chat_template (#1246) * Refactor `get_chat_template` to now support system message instead. It supposed to fix ollama tokenizer chattemplate to * Remove type hinting * Update chat_templates.py --------- Co-authored-by: Daniel Han --- unsloth/chat_templates.py | 97 +++++++++++++++++++++++++++++++++++---- 1 file changed, 89 insertions(+), 8 deletions(-) diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index b254202c75..da10f7e003 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -39,6 +39,7 @@ train_on_responses_only, ) CHAT_TEMPLATES = {} +DEFAULT_SYSTEM_MESSAGE = {} # =========================================== Unsloth # Unsloth efficient template leverages from Zephyr @@ -48,7 +49,7 @@ "{{ messages[0]['content'] + '\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'You are a helpful assistant to the user\n' }}"\ + "{{ '{system_message}' + '\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -80,6 +81,7 @@ unsloth_eos_token = "eos_token" CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,) +DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user" pass # =========================================== Zephyr @@ -116,6 +118,7 @@ zephyr_eos_token = "eos_token" CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,) +DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr pass # =========================================== ChatML @@ -153,6 +156,7 @@ chatml_eos_token = "<|im_end|>" CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML pass # =========================================== Mistral-1 @@ -193,6 +197,7 @@ mistral_eos_token = "eos_token" CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,) +DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral pass # =========================================== Llama-2 @@ -234,6 +239,7 @@ llama_eos_token = "eos_token" CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama pass # =========================================== Vicuna @@ -244,7 +250,7 @@ "{{ messages[0]['content'] + ' ' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\ + "{{ '{system_message}' + ' ' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -273,6 +279,7 @@ vicuna_eos_token = "eos_token" CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,) +DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." pass # =========================================== Vicuna Old @@ -283,7 +290,7 @@ "{{ messages[0]['content'] + '\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\ + "{{ '{system_message}' + '\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -315,6 +322,10 @@ vicuna_old_eos_token = "eos_token" CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,) +DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions." + +CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"] +DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"] pass # =========================================== Alpaca multi turn @@ -325,7 +336,7 @@ "{{ messages[0]['content'] + '\n\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\ + "{{ '{system_message}' + '\n\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -362,6 +373,7 @@ alpaca_eos_token = "eos_token" CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,) +DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request." pass # =========================================== Gemma @@ -372,7 +384,7 @@ "{{ bos_token }}"\ "{% if messages[0]['role'] == 'system' %}"\ "{{'user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '\n'}}"\ - "{% set loop_messages = messages[2:] %}"\ + "{% set messages = messages[2:] %}"\ "{% endif %}"\ "{% for message in messages %}"\ "{% if message['role'] == 'user' %}"\ @@ -407,6 +419,7 @@ gemma_eos_token = "" CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma pass # =========================================== Gemma with ChatML instead @@ -437,6 +450,7 @@ "<|im_end|>", ) CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma pass # =========================================== Gemma 2 @@ -446,12 +460,14 @@ gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n" gemma2_eos_token = "" CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2 # =========================================== Gemma 2 with ChatML instead gemma2_chatml_template = gemma_chatml_template gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n" gemma2_chatml_eos_token = gemma_chatml_eos_token CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2 pass # =========================================== Llama-3 @@ -491,7 +507,12 @@ ''' llama3_template_eos_token = "eos_token" + CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3 + +CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3 pass @@ -532,8 +553,13 @@ phi3_template_eos_token = "<|end|>" CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,) +DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3 + CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"] +DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5 + CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"] +DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5 pass # =========================================== Llama-3.1 @@ -573,7 +599,7 @@ {%- set system_message = messages[0]['content'] %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "" %} + {%- set system_message = "{system_message}" %} {%- endif %} {#- System message + builtin tools #} @@ -729,7 +755,10 @@ llama31_template_eos_token = "eos_token" CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates + CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates pass @@ -751,7 +780,7 @@ {%- if messages[0][\'role\'] == \'system\' %} {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }} {%- else %} - {{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }} + {{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }} {%- endif %}\n{%- endif %}\n{%- for message in messages %} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }} @@ -847,10 +876,53 @@ ''' qwen25_template_eos_token = "eos_token" +qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 +pass + +def _change_system_message(template: str, type_chat_template: str, system_message: str = None): + system_message_pattern = r"\{system_message\}" + + # For predefined templates, check if default system message exists + default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None) + if default_system_message is None: + if system_message is not None: + logger.warning_once( + f"Unsloth: You tried to change the system message for {type_chat_template}, " + "but it doesn't have a default system message. " + "You need to manually add the system message in your data." + ) + return template, system_message + pass + + # For custom templates + if type_chat_template is None: + has_placeholder = re.search(system_message_pattern, template) is not None + + if has_placeholder: + if system_message is None: + raise ValueError("Unsloth: You need to provide a system message for custom templates.") + new_template = re.sub(system_message_pattern, system_message, template) + return new_template, system_message + + return template, system_message + pass + + # For predefined templates with default system message + message_to_use = system_message if system_message is not None else default_system_message + new_template = re.sub(system_message_pattern, message_to_use, template) + + return new_template, message_to_use pass @@ -886,14 +958,20 @@ def get_chat_template( old_padding_side = tokenizer.padding_side same_padding_token = False - + type_chat_template = None + if type(chat_template) in (list, tuple,): + # For changing system message later + # Since it's not supported yet, we will raise an error first! + type_chat_template = chat_template[0].lower() chat_template, stop_word = chat_template assert(type(chat_template) is str) assert(type(stop_word) is str) ollama_modelfile = None elif type(chat_template) is str: + # For changing system message later + type_chat_template = chat_template.lower() chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template] @@ -1052,6 +1130,9 @@ def get_chat_template( else: chat_template = new_chat_template pass + + chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message) + tokenizer.chat_template = chat_template # Also fix up other tokens From dc0232c883b8f0ac44b59bc4ae917f73402448f5 Mon Sep 17 00:00:00 2001 From: Edd <68678137+Erland366@users.noreply.github.com> Date: Thu, 14 Nov 2024 05:33:30 +0400 Subject: [PATCH 204/209] fix/sft-trainer (#1276) * Add patch for SFTTrainer to maintain backward compatibility with TRL changes * Update trainer.py * Update trainer.py * Refactor trainer patch to maintain backward compatibility with TRL changes * Update trainer.py * Refactor trainer.py to exclude non-convertible trainers from backward compatibility patch --------- Co-authored-by: Daniel Han --- unsloth/__init__.py | 3 ++ unsloth/trainer.py | 112 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 105 insertions(+), 10 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index d5651d5edb..d94eeb8973 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -176,3 +176,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 from .chat_templates import * from .tokenizer_utils import * from .trainer import * + +# patch sft trainer +_patch_trl_trainer() \ No newline at end of file diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 25bb434023..b03cc347c6 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass, field from typing import Optional +from functools import wraps +import trl +import inspect from trl import SFTTrainer try: from trl import SFTConfig as TrainingArguments @@ -24,30 +28,38 @@ from . import is_bfloat16_supported from unsloth_zoo.training_utils import unsloth_train as _unsloth_train from packaging.version import Version +import dataclasses + +__all__ = [ + "UnslothTrainingArguments", + "UnslothTrainer", + "unsloth_train", + "_patch_trl_trainer", +] # Unsloth gradient accumulation fix: from transformers import __version__ as transformers_version if Version(transformers_version) > Version("4.45.2"): - def unsloth_train(trainer): - return trainer.train() + def unsloth_train(trainer, *args, **kwargs): + return trainer.train(*args, **kwargs) pass else: - def unsloth_train(trainer): + def unsloth_train(trainer, *args, **kwargs): + if len(args) != 0 or len(kwargs) != 0: + raise RuntimeError( + "Unsloth: Our custom gradient accumulation fixed trainer does not support other arguments.\n"\ + "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\ + '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`' + ) print( "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\ "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\ - '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`' + '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir transformers`' ) return _unsloth_train(trainer) pass pass -__all__ = [ - "UnslothTrainingArguments", - "UnslothTrainer", - "unsloth_train", -] - @dataclass class UnslothTrainingArguments(TrainingArguments): @@ -119,3 +131,83 @@ def create_optimizer(self): return self.optimizer pass pass + +# From `trl>=0.13.0`, they changed how to pass several params to the trainer +# We need to patch to make the transition smooth +def create_backwards_compatible_trainer(trainer_class, config_class): + original_init = trainer_class.__init__ + + @wraps(original_init) + def new_init(self, *args, **kwargs): + # All Trainer tokenizer is now called processing_class + if "tokenizer" in kwargs: + kwargs["processing_class"] = kwargs.pop("tokenizer") + + if "args" in kwargs: + training_args = kwargs.pop("args", None) + + # Get parameters that Trainer.__init__ actually expects + trainer_params = set(inspect.signature(original_init).parameters.keys()) + trainer_params.remove('self') + trainer_params.remove('args') + + # Get fields that should be passed to Config init + config_fields = { + field.name: field for field in dataclasses.fields(config_class) + if field.init + } + + # Create config dict with valid fields from training_args + config_dict = { + name: getattr(training_args, name) + for name in config_fields + if hasattr(training_args, name) + } + + # Get parameters that exist in Config but not in TrainingArguments + moved_params = \ + set(inspect.signature(config_class) .parameters.keys()) - \ + set(inspect.signature(TrainingArguments).parameters.keys()) + + # Separate kwargs into trainer kwargs and config kwargs + trainer_kwargs = {} + additional_config_kwargs = {} + + for key, value in kwargs.items(): + if key in trainer_params: trainer_kwargs[key] = value + elif key in moved_params or key in config_fields: + additional_config_kwargs[key] = value + else: + additional_config_kwargs[key] = value + pass + + # Update config_dict with additional kwargs + config_dict.update(additional_config_kwargs) + + # Create Config with all the collected parameters + config = config_class(**config_dict) + + # Reconstruct kwargs for Trainer + kwargs = trainer_kwargs + kwargs["args"] = config + pass + original_init(self, *args, **kwargs) + pass + return new_init + +if Version(trl.__version__) >= Version("0.13.0.dev0"): + # print("Patching TRL Trainer to maintain backward compatibility with the old syntax.") + def _patch_trl_trainer(): + import trl.trainer + trl_classes = dir(trl.trainer) + + non_convertable_trainer = set(["PPOv2", "AlignProp"]) + trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) - non_convertable_trainer + trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config")) - non_convertable_trainer + trl_classes = list(trl_trainers & trl_configs) + for x in trl_classes: + exec(f"trl.{x}Trainer.__init__ = create_backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals()) + pass +else: + def _patch_trl_trainer(): return +pass From 84d6d36cdeb8f6b1fc7b659e208fd8dd7f01d76d Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 13 Nov 2024 17:38:26 -0800 Subject: [PATCH 205/209] Update __init__.py --- unsloth/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index d94eeb8973..745b210208 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -177,5 +177,5 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 from .tokenizer_utils import * from .trainer import * -# patch sft trainer -_patch_trl_trainer() \ No newline at end of file +# Patch TRL trainers for backwards compatibility +_patch_trl_trainer() From a31027c9e4c433f5376eeb07e664120475a821d6 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 13 Nov 2024 18:44:54 -0800 Subject: [PATCH 206/209] Update trainer.py --- unsloth/trainer.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index b03cc347c6..896db35ecd 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -139,15 +139,17 @@ def create_backwards_compatible_trainer(trainer_class, config_class): @wraps(original_init) def new_init(self, *args, **kwargs): - # All Trainer tokenizer is now called processing_class - if "tokenizer" in kwargs: + # All Trainer tokenizer are now called processing_class + trainer_params = set(inspect.signature(original_init).parameters.keys()) + + if "processing_class" in trainer_params and "tokenizer" in kwargs: kwargs["processing_class"] = kwargs.pop("tokenizer") + pass - if "args" in kwargs: + if ("args" in kwargs) and (Version(trl.__version__) >= Version("0.13.0.dev0")): training_args = kwargs.pop("args", None) # Get parameters that Trainer.__init__ actually expects - trainer_params = set(inspect.signature(original_init).parameters.keys()) trainer_params.remove('self') trainer_params.remove('args') @@ -179,6 +181,7 @@ def new_init(self, *args, **kwargs): additional_config_kwargs[key] = value else: additional_config_kwargs[key] = value + pass pass # Update config_dict with additional kwargs @@ -194,20 +197,24 @@ def new_init(self, *args, **kwargs): original_init(self, *args, **kwargs) pass return new_init +pass + + +def _patch_trl_trainer(): + if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return + if Version(trl.__version__) <= Version("0.11.0"): return -if Version(trl.__version__) >= Version("0.13.0.dev0"): - # print("Patching TRL Trainer to maintain backward compatibility with the old syntax.") - def _patch_trl_trainer(): - import trl.trainer - trl_classes = dir(trl.trainer) - - non_convertable_trainer = set(["PPOv2", "AlignProp"]) - trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) - non_convertable_trainer - trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config")) - non_convertable_trainer - trl_classes = list(trl_trainers & trl_configs) - for x in trl_classes: - exec(f"trl.{x}Trainer.__init__ = create_backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals()) + import trl.trainer + trl_classes = dir(trl.trainer) + + non_convertable_trainer = set(["PPOv2", "AlignProp"]) + trl_trainers = set(x[:-len("Trainer")] for x in trl_classes if x.endswith("Trainer")) - non_convertable_trainer + trl_configs = set(x[:-len("Config")] for x in trl_classes if x.endswith("Config")) - non_convertable_trainer + trl_classes = list(trl_trainers & trl_configs) + + for x in trl_classes: + exec(f"trl.{x}Trainer.__init__ = create_backwards_compatible_trainer(trl.{x}Trainer, trl.{x}Config)", globals()) pass -else: - def _patch_trl_trainer(): return + + trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ = True pass From 035bccee7560558f0de37066e5d7b7b294eab939 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 13 Nov 2024 18:48:59 -0800 Subject: [PATCH 207/209] Update trainer.py --- unsloth/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 896db35ecd..25b8999959 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -201,7 +201,10 @@ def new_init(self, *args, **kwargs): def _patch_trl_trainer(): - if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return + try: + trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ + return + except: pass if Version(trl.__version__) <= Version("0.11.0"): return import trl.trainer From 597169c14bfe28e9484066a80a09a14e47528519 Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 13 Nov 2024 18:53:40 -0800 Subject: [PATCH 208/209] Update trainer.py --- unsloth/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/unsloth/trainer.py b/unsloth/trainer.py index 25b8999959..00956ed41b 100644 --- a/unsloth/trainer.py +++ b/unsloth/trainer.py @@ -201,10 +201,8 @@ def new_init(self, *args, **kwargs): def _patch_trl_trainer(): - try: - trl.__UNSLOTH_BACKWARDS_COMPATIBLE__ - return - except: pass + import trl + if hasattr(trl, "__UNSLOTH_BACKWARDS_COMPATIBLE__"): return if Version(trl.__version__) <= Version("0.11.0"): return import trl.trainer From 11b350f7c17a6f313915fb71778b40ae35a34ddb Mon Sep 17 00:00:00 2001 From: Daniel Han Date: Wed, 13 Nov 2024 19:05:15 -0800 Subject: [PATCH 209/209] Update tokenizer_utils.py --- unsloth/tokenizer_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index ed95e07632..302017d566 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -586,10 +586,10 @@ def load_correct_tokenizer( def _fix_chat_template(chat_template): - endfor = "{% endfor %}" + endfor = "{% endif %}" where = chat_template.find(endfor) if where == -1: - endfor = "{%- endfor %}" + endfor = "{%- endif %}" where = chat_template.find(endfor) if where == -1: return chat_template