diff --git a/README.md b/README.md index 6b583d5fd5..d0e5a97999 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,18 @@ x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"') ``` +### Windows Installation + +To run Unsloth directly on Windows: +- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows +- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue: +```python +trainer = SFTTrainer( + dataset_num_proc=1, + ... +) +``` + For **advanced installation instructions** or if you see weird errors during installations: 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton` diff --git a/pyproject.toml b/pyproject.toml index 455f8477e8..a2a9c2c939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ huggingface = [ "wheel>=0.42.0", "numpy", "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1,!=0.11.0", "protobuf<4.0.0", "huggingface_hub", @@ -227,7 +227,7 @@ colab-new = [ ] colab-no-deps = [ "accelerate>=0.34.1", - "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1", + "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3", "peft>=0.7.1", "xformers<0.0.27", "bitsandbytes>=0.43.3", diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index cab6130dd1..b254202c75 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -678,7 +678,7 @@ {{- end }} {{- if .Tools }} -You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question. +You are a helpful assistant with tool calling capabilities. 
When you receive a tool call response, use the output to format an answer to the original use question. {{- end }} {{- end }}<|eot_id|> {{- range $i, $_ := .Messages }} diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 2fba359b77..08426b69e0 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -31,7 +31,7 @@ create_block_mask as _create_block_mask, ) _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options) - HAS_FLEX_ATTENTION = True + HAS_FLEX_ATTENTION = False except: HAS_FLEX_ATTENTION = False pass diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4aa8e1e592..25024be2ec 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2024.10.3" +__version__ = "2024.10.4" __all__ = [ "prepare_model_for_kbit_training", @@ -1194,8 +1194,8 @@ def patch_gradient_accumulation_fix(Trainer): logger.warning_once( "Unsloth: We fixed a gradient accumulation bug, "\ "but it seems like you don't have the latest transformers version!\n"\ - "Please update transformers via:\n"\ - '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`' + "Please update transformers, TRL and unsloth via:\n"\ + '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`' ) pass pass diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 15e9efc426..00dcc5cd1d 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -254,8 +254,9 @@ def MistralForCausalLM_fast_forward( shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) loss = fast_cross_entropy_loss( - logits = shift_logits, - labels = 
shift_labels, + logits = shift_logits, + labels = shift_labels, + n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None), ) pass diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index ffe9933f47..1cad00d44d 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -831,6 +831,7 @@ def check_nvidia(): PRE_CHECK = check_nvidia() +import inspect from inspect import getsource import trl.trainer.sft_trainer from trl.trainer.sft_trainer import * @@ -869,6 +870,35 @@ def neftune_post_forward_hook(module, input, output): pass +def patch_trl_tokenizer_processing_class(trainer_name): + # New TRL removes tokenizer! + # We return it back! + exec(f"from trl import {trainer_name}", globals()) + if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None + parameters = eval(f"inspect.signature({trainer_name}).parameters") + if "tokenizer" in parameters: return None + + args = { + key : \ + value.default \ + if type(value.default) is not str else \ + f"'{value.default}'" \ + for key, value in parameters.items() + } + args["tokenizer"] = None + new_args = args.copy() + del new_args["tokenizer"] + del new_args["processing_class"] + new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \ + f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class" + args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items()) + args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):" + args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})" + new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n""" + return new_class +pass + + def patch_sft_trainer_tokenizer(): """ Patches the trainer with changes @@ -884,7 +914,8 @@ def patch_sft_trainer_tokenizer(): check_text = \ "\n"\ - "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\ + "if 'tokenizer' not 
in locals(): tokenizer = processing_class\n"\ + "test_text = dataset[0][dataset_text_field] if (formatting_func is None and dataset_text_field is not None) else formatting_func(dataset[0])[0]\n"\ "chat_template = getattr(tokenizer, 'chat_template', None)\n"\ "chat_template = '' if chat_template is None else chat_template\n"\ "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\ @@ -941,7 +972,8 @@ def patch_sft_trainer_tokenizer(): " from transformers import __version__ as transformers_version\n"\ " from packaging.version import Version\n"\ " if Version(transformers_version) <= Version('4.45.2'):\n"\ - " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\ + " print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\ + " '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\ "except:\n"\ " pass\n"\ "\n\n" @@ -981,4 +1013,13 @@ def patch_sft_trainer_tokenizer(): pass pass +# Fix TRL trainers with removed tokenizer args (got replaced with processing_class) +for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"): + trainer_text = patch_trl_tokenizer_processing_class(trainer_name) + if trainer_text is None: continue + exec(trainer_text, globals()) + exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals()) +pass + +# Finally patch TRL tokenizer things
 patch_sft_trainer_tokenizer()