Skip to content
Merged
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```

### Windows Installation

To run Unsloth directly on Windows:
- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
```python
trainer = SFTTrainer(
dataset_num_proc=1,
...
)
```

For **advanced installation instructions** or if you see weird errors during installations:

1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ huggingface = [
"wheel>=0.42.0",
"numpy",
"accelerate>=0.34.1",
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
"peft>=0.7.1,!=0.11.0",
"protobuf<4.0.0",
"huggingface_hub",
Expand Down Expand Up @@ -227,7 +227,7 @@ colab-new = [
]
colab-no-deps = [
"accelerate>=0.34.1",
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
"peft>=0.7.1",
"xformers<0.0.27",
"bitsandbytes>=0.43.3",
Expand Down
2 changes: 1 addition & 1 deletion unsloth/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@
{{- end }}
{{- if .Tools }}

You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the orginal use question.
You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
{{- end }}
{{- end }}<|eot_id|>
{{- range $i, $_ := .Messages }}
Expand Down
2 changes: 1 addition & 1 deletion unsloth/kernels/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
create_block_mask as _create_block_mask,
)
_flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
HAS_FLEX_ATTENTION = True
HAS_FLEX_ATTENTION = False
except:
HAS_FLEX_ATTENTION = False
pass
Expand Down
6 changes: 3 additions & 3 deletions unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2024.10.3"
__version__ = "2024.10.4"

__all__ = [
"prepare_model_for_kbit_training",
Expand Down Expand Up @@ -1194,8 +1194,8 @@ def patch_gradient_accumulation_fix(Trainer):
logger.warning_once(
"Unsloth: We fixed a gradient accumulation bug, "\
"but it seems like you don't have the latest transformers version!\n"\
"Please update transformers via:\n"\
'`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
"Please update transformers, TRL and unsloth via:\n"\
'`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`'
)
pass
pass
5 changes: 3 additions & 2 deletions unsloth/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ def MistralForCausalLM_fast_forward(

shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
loss = fast_cross_entropy_loss(
logits = shift_logits,
labels = shift_labels,
logits = shift_logits,
labels = shift_labels,
n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
)
pass

Expand Down
45 changes: 43 additions & 2 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,7 @@ def check_nvidia():
PRE_CHECK = check_nvidia()


import inspect
from inspect import getsource
import trl.trainer.sft_trainer
from trl.trainer.sft_trainer import *
Expand Down Expand Up @@ -869,6 +870,35 @@ def neftune_post_forward_hook(module, input, output):
pass


def patch_trl_tokenizer_processing_class(trainer_name):
    """Build source for an ``Unsloth<Trainer>`` subclass that restores ``tokenizer=``.

    Newer TRL releases replaced the ``tokenizer`` constructor argument with
    ``processing_class``. This generates (as a source string) a subclass whose
    ``__init__`` accepts ``tokenizer`` again and forwards it as
    ``processing_class`` to ``super().__init__``.

    Parameters
    ----------
    trainer_name : str
        Name of a trainer class importable from ``trl`` (e.g. ``"SFTTrainer"``).

    Returns
    -------
    str or None
        Source code of the patched class, or ``None`` when no patch is needed
        (already patched, or the trainer still accepts ``tokenizer``).
    """
    # New TRL removes tokenizer!
    # We return it back!
    # NOTE: exec/eval on a caller-supplied name — callers pass fixed literals
    # below; do not feed untrusted strings here.
    exec(f"from trl import {trainer_name}", globals())
    # Already wrapped (class renamed "Unsloth...") — nothing to do.
    if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None
    parameters = eval(f"inspect.signature({trainer_name}).parameters")
    # Trainer still takes `tokenizer` directly — no patch needed.
    if "tokenizer" in parameters: return None

    # Map each __init__ parameter to its default, re-quoting string defaults
    # so they survive being pasted back into generated source.
    args = {
        key : \
        value.default \
        if type(value.default) is not str else \
        f"'{value.default}'" \
        for key, value in parameters.items()
    }
    # Re-introduce `tokenizer` (default None) on the generated signature.
    args["tokenizer"] = None
    new_args = args.copy()
    # The super().__init__ call must not receive `tokenizer`, and
    # `processing_class` is set explicitly below.
    del new_args["tokenizer"]
    del new_args["processing_class"]
    # Forward every parameter by name; prefer `tokenizer` when given,
    # otherwise fall back to `processing_class`.
    new_args = ",\n".join(f"{' '*12}{key} = {key}" for key in new_args) + \
        f",\n{' '*12}processing_class = tokenizer if tokenizer else processing_class"
    # Assemble the generated __init__ signature (8-space indent inside the
    # 4-space-indented class body, matching the f-string layout below).
    args = ",\n".join(f"{' '*8}{key} = {value}" for key, value in args.items())
    args = f"def __init__(\n" + f"{' '*8}self,\n" + args + "):"
    args += f"\n{' '*8}\n{' '*8}super().__init__(\n{new_args}\n{' '*8})"
    new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' '*4}{args}\n"""
    return new_class
pass


def patch_sft_trainer_tokenizer():
"""
Patches the trainer with changes
Expand All @@ -884,7 +914,8 @@ def patch_sft_trainer_tokenizer():

check_text = \
"\n"\
"test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\
"if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
"test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n"\
"chat_template = getattr(tokenizer, 'chat_template', None)\n"\
"chat_template = '' if chat_template is None else chat_template\n"\
"has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
Expand Down Expand Up @@ -941,7 +972,8 @@ def patch_sft_trainer_tokenizer():
" from transformers import __version__ as transformers_version\n"\
" from packaging.version import Version\n"\
" if Version(transformers_version) <= Version('4.45.2'):\n"\
" print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\
" print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
" '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\
"except:\n"\
" pass\n"\
"\n\n"
Expand Down Expand Up @@ -981,4 +1013,13 @@ def patch_sft_trainer_tokenizer():
pass
pass

# Fix TRL trainers with removed tokenizer args (got replaced with processing_class)
# For each trainer, generate an Unsloth<Trainer> subclass source (None means the
# installed TRL needs no patch), exec it into this module's globals, then
# rebind the name inside trl.trainer so downstream imports pick up the patched class.
for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"):
    trainer_text = patch_trl_tokenizer_processing_class(trainer_name)
    if trainer_text is None: continue
    exec(trainer_text, globals())
    exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals())
pass

# Finally patch TRL tokenizer things
patch_sft_trainer_tokenizer()