Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
219 commits
Select commit Hold shift + click to select a range
f0aca90
Fix TRL
danielhanchen Oct 21, 2024
f4ae585
Update mistral.py
danielhanchen Oct 22, 2024
106f213
Patch processing_class
danielhanchen Oct 22, 2024
ef84212
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
4f7c527
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aa2b207
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
101389d
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
c0f0fc9
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
b3e0033
Update tokenizer_utils.py
danielhanchen Oct 22, 2024
aabb5ff
Installation guide (#1165)
timothelaborie Oct 23, 2024
30bf339
chore: update chat_templates.py (#1166)
eltociear Oct 23, 2024
2895839
Disable Flex Attention
danielhanchen Oct 23, 2024
06f5d75
Update tokenizer_utils.py
danielhanchen Oct 23, 2024
28e6eea
Update _utils.py
danielhanchen Oct 23, 2024
b821f20
n_items
danielhanchen Oct 24, 2024
e561366
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
4ff247a
Fix DPO, ORPO
danielhanchen Oct 24, 2024
2b858a5
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
1c063b4
Update _utils.py
danielhanchen Oct 24, 2024
f195ee1
Update _utils.py
danielhanchen Oct 24, 2024
faf2747
fix/transformers-unpack (#1180)
Erland366 Oct 24, 2024
5961c34
Update cross_entropy_loss.py
danielhanchen Oct 24, 2024
7308bb8
Update _utils.py
danielhanchen Oct 24, 2024
0096e5b
Update _utils.py
danielhanchen Oct 24, 2024
44b480f
Merge branch 'main' into nightly
danielhanchen Oct 24, 2024
6776055
donot upcast lm_head and embeddings to float32 (#1186)
Datta0 Oct 25, 2024
625209e
Cleanup upcast logs (#1188)
Datta0 Oct 25, 2024
2bc189f
Fix/phi-longrope (#1193)
Erland366 Oct 25, 2024
6f28d16
Update transformers
danielhanchen Oct 26, 2024
f94f7c1
Merge branch 'main' into nightly
danielhanchen Oct 26, 2024
bf3b175
Merge branch 'main' into nightly
danielhanchen Oct 27, 2024
7083a1d
Unk token issues
danielhanchen Oct 28, 2024
3acc5af
Update _utils.py
danielhanchen Oct 28, 2024
1c044da
Fix pad token
danielhanchen Oct 28, 2024
5286f19
Update llama.py
danielhanchen Oct 28, 2024
02437a8
Typo
danielhanchen Oct 28, 2024
9d07be0
ignored labels
danielhanchen Oct 28, 2024
a8b37a3
Revert "ignored labels"
danielhanchen Oct 28, 2024
2dfdba3
More patching
danielhanchen Oct 28, 2024
5541ab4
Update _utils.py
danielhanchen Oct 28, 2024
c6e9af2
Update _utils.py
danielhanchen Oct 28, 2024
cac56d1
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
5ee1189
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
85a5f60
Update cross_entropy_loss.py
danielhanchen Oct 28, 2024
20e38ed
Feat/all tmp (#1219)
danielhanchen Oct 30, 2024
7e1692a
Bug fixes
danielhanchen Oct 30, 2024
6bef8f1
Update pyproject.toml
danielhanchen Oct 30, 2024
9ccbc0e
Update _utils.py
danielhanchen Oct 30, 2024
95ecc57
Update __init__.py
danielhanchen Oct 30, 2024
5f5fef8
Update __init__.py
danielhanchen Oct 30, 2024
784dd13
Update _utils.py
danielhanchen Oct 30, 2024
5b75e21
Update _utils.py
danielhanchen Oct 30, 2024
74ab93c
Update _utils.py
danielhanchen Oct 30, 2024
526505c
Update _utils.py
danielhanchen Oct 30, 2024
251ba77
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
530c495
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
07394c3
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
6d7004b
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
d86b20a
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9920950
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
9f926ce
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
30cdf65
Update cross_entropy_loss.py
danielhanchen Oct 30, 2024
54b901b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
6db9d28
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8aefcd0
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
7bf626b
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
d455751
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
055eeb8
Update cross_entropy_loss.py
danielhanchen Oct 31, 2024
8090b7c
Tied weights
danielhanchen Oct 31, 2024
7559efb
Revert "Tied weights"
danielhanchen Oct 31, 2024
ad63a32
Tied weights
danielhanchen Oct 31, 2024
35aa992
Utils
danielhanchen Nov 3, 2024
0172ee3
CE Loss patching
danielhanchen Nov 3, 2024
c228682
Update __init__.py
danielhanchen Nov 3, 2024
9aa221a
Update __init__.py
danielhanchen Nov 3, 2024
751413e
Patching
danielhanchen Nov 3, 2024
82db087
Update cross_entropy_loss.py
danielhanchen Nov 3, 2024
cf68202
CE Loss
danielhanchen Nov 3, 2024
63a1828
Update _utils.py
danielhanchen Nov 3, 2024
3f0e56f
Update _utils.py
danielhanchen Nov 3, 2024
1190ed4
CE Loss
danielhanchen Nov 3, 2024
607ac34
Update _utils.py
danielhanchen Nov 3, 2024
32eac0b
Update _utils.py
danielhanchen Nov 3, 2024
5b6d401
Layernorm
danielhanchen Nov 4, 2024
3d19a71
Update _utils.py
danielhanchen Nov 4, 2024
76da511
Update _utils.py
danielhanchen Nov 4, 2024
013ebaa
Post patch
danielhanchen Nov 4, 2024
608916a
Update _utils.py
danielhanchen Nov 4, 2024
19836e3
Update llama.py
danielhanchen Nov 4, 2024
0164087
Update _utils.py
danielhanchen Nov 4, 2024
205f7ad
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2f1f393
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
05b8f66
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8d205c0
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
a1e9e13
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
94655f8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
085f998
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c796fd9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
e943d77
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
16a7df6
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f65b064
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
1ff49b8
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
080e558
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
f6d50c7
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
fad4202
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
736b16a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
eb76416
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
367e43f
typing
danielhanchen Nov 4, 2024
993df20
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
8f566b3
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
22bb46b
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
b5c9f81
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
c7b2220
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
2d0ab26
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
428f662
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5023ce9
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
5ca3d4a
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
3b32d81
int64
danielhanchen Nov 4, 2024
9bae6e2
Update _utils.py
danielhanchen Nov 4, 2024
5123623
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b1d9e2
constexpr
danielhanchen Nov 4, 2024
7d5111a
constexpr
danielhanchen Nov 4, 2024
dff5a52
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
969d1bd
Update cross_entropy_loss.py
danielhanchen Nov 4, 2024
4b5847f
Update _utils.py
danielhanchen Nov 4, 2024
766bf1e
Update _utils.py
danielhanchen Nov 4, 2024
646f1b7
Update _utils.py
danielhanchen Nov 5, 2024
97f37ac
CE
danielhanchen Nov 5, 2024
cc563fa
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
f643148
Update _utils.py
danielhanchen Nov 5, 2024
f28d7f6
Update llama.py
danielhanchen Nov 5, 2024
d8103e1
Update _utils.py
danielhanchen Nov 5, 2024
b9e1a49
Update rms_layernorm.py
danielhanchen Nov 5, 2024
56af302
Update rms_layernorm.py
danielhanchen Nov 5, 2024
a3c84a3
Update rms_layernorm.py
danielhanchen Nov 5, 2024
f7d5c56
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8496ff6
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2909eaf
Update rms_layernorm.py
danielhanchen Nov 5, 2024
afc8af6
Update utils.py
danielhanchen Nov 5, 2024
2d8d1e1
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ecc1ad2
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ae7cb78
Update rms_layernorm.py
danielhanchen Nov 5, 2024
22da266
Update rms_layernorm.py
danielhanchen Nov 5, 2024
beb6854
Update rms_layernorm.py
danielhanchen Nov 5, 2024
14c3d2f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef4b079
Update rms_layernorm.py
danielhanchen Nov 5, 2024
ef684f8
Update rms_layernorm.py
danielhanchen Nov 5, 2024
3e4c42f
Update rms_layernorm.py
danielhanchen Nov 5, 2024
8f825eb
Update rms_layernorm.py
danielhanchen Nov 5, 2024
bd4ac7b
Update rms_layernorm.py
danielhanchen Nov 5, 2024
6f38731
Update rms_layernorm.py
danielhanchen Nov 5, 2024
2df35d4
typing
danielhanchen Nov 5, 2024
74d89d1
Update rope_embedding.py
danielhanchen Nov 5, 2024
98927ee
types
danielhanchen Nov 5, 2024
f3e2bd6
Disable compiling
danielhanchen Nov 5, 2024
c30bd2a
Update _utils.py
danielhanchen Nov 5, 2024
813cbdd
Update _utils.py
danielhanchen Nov 5, 2024
34ce5d1
Forward hook
danielhanchen Nov 5, 2024
f84cf4b
Update _utils.py
danielhanchen Nov 5, 2024
745814c
Update llama.py
danielhanchen Nov 5, 2024
ab9f8e1
Update _utils.py
danielhanchen Nov 5, 2024
daa7909
Update llama.py
danielhanchen Nov 5, 2024
536a1a6
Update llama.py
danielhanchen Nov 5, 2024
648ca59
Update _utils.py
danielhanchen Nov 5, 2024
486d0d6
Update pyproject.toml
danielhanchen Nov 5, 2024
eb4da9d
Update _utils.py
danielhanchen Nov 5, 2024
da397f4
Update llama.py
danielhanchen Nov 5, 2024
70b65cf
CE Loss
danielhanchen Nov 5, 2024
aeec57e
Update cross_entropy_loss.py
danielhanchen Nov 5, 2024
fb393fc
Update _utils.py
danielhanchen Nov 5, 2024
cab1e72
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
51fea97
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
58e541b
Update cross_entropy_loss.py
danielhanchen Nov 6, 2024
0ed0532
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
ef2c56f
Update llama.py
danielhanchen Nov 6, 2024
24ab0d2
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
13d7412
Update _utils.py
danielhanchen Nov 6, 2024
5a7eaf8
Update _utils.py
danielhanchen Nov 6, 2024
d2186ed
Update _utils.py
danielhanchen Nov 6, 2024
6434447
Update _utils.py
danielhanchen Nov 6, 2024
67611e6
Update _utils.py
danielhanchen Nov 6, 2024
36c5836
Merge branch 'main' into nightly
danielhanchen Nov 6, 2024
f24aef5
Fix: cast logits to float32 in cross_entropy_forward to prevent error…
Erland366 Nov 6, 2024
3d906e6
Throw error when inferencing longer than max_position_embeddings (#1…
Datta0 Nov 6, 2024
de1049b
CLI now handles user input strings for dtype correctly (#1235)
Rabbidon Nov 6, 2024
be72975
Update flex_attention.py
danielhanchen Nov 6, 2024
05170cd
Update _utils.py
danielhanchen Nov 6, 2024
7e0877d
Update _utils.py
danielhanchen Nov 6, 2024
6b5c599
Update flex_attention.py
danielhanchen Nov 6, 2024
1ba9f2e
Update flex_attention.py
danielhanchen Nov 6, 2024
da61c4d
Update loader.py
danielhanchen Nov 6, 2024
3316ee2
Update loader.py
danielhanchen Nov 6, 2024
501ca84
Update flex_attention.py
danielhanchen Nov 6, 2024
ce621b7
Update flex_attention.py
danielhanchen Nov 6, 2024
4b01ff1
Update flex_attention.py
danielhanchen Nov 6, 2024
ef5052a
Update flex_attention.py
danielhanchen Nov 7, 2024
52bca32
Update _utils.py
danielhanchen Nov 7, 2024
68b8d62
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
15da065
Merge branch 'main' into nightly
danielhanchen Nov 7, 2024
8b3e9c2
Update cross_entropy_loss.py
danielhanchen Nov 7, 2024
3a1e7ef
Update _utils.py
danielhanchen Nov 7, 2024
f1ec165
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
a4e9705
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
92c6a27
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
673f541
Update tokenizer_utils.py
danielhanchen Nov 10, 2024
8fe9109
Update tokenizer_utils.py
danielhanchen Nov 11, 2024
ad41479
triton_cast
danielhanchen Nov 11, 2024
fcf2009
Update utils.py
danielhanchen Nov 11, 2024
af9ba07
Qwen 2.5 Coder
danielhanchen Nov 12, 2024
e99acdd
Merge branch 'main' into nightly
danielhanchen Nov 13, 2024
3fec577
Fix/export mistral (#1281)
Erland366 Nov 13, 2024
03c6243
DOC Update - Update README.md with os.environ in example (#1269)
udaygirish Nov 13, 2024
10565ef
fix/get_chat_template (#1246)
Erland366 Nov 13, 2024
dc0232c
fix/sft-trainer (#1276)
Erland366 Nov 14, 2024
84d6d36
Update __init__.py
danielhanchen Nov 14, 2024
a31027c
Update trainer.py
danielhanchen Nov 14, 2024
035bcce
Update trainer.py
danielhanchen Nov 14, 2024
597169c
Update trainer.py
danielhanchen Nov 14, 2024
11b350f
Update tokenizer_utils.py
danielhanchen Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,9 @@ DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as
We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID

from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
Expand Down
7 changes: 7 additions & 0 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
# enabling it will require much more work, so we have to prioritize. Please understand!
# We do have a beta version, which you can contact us about!
# Thank you for your understanding and we appreciate it immensely!

# Fixes https://github.com/unslothai/unsloth/issues/1266
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

if "CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
devices = os.environ["CUDA_VISIBLE_DEVICES"]
Expand Down Expand Up @@ -172,3 +176,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
97 changes: 89 additions & 8 deletions unsloth/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
train_on_responses_only,
)
CHAT_TEMPLATES = {}
DEFAULT_SYSTEM_MESSAGE = {}

# =========================================== Unsloth
# Unsloth efficient template leverages from Zephyr
Expand All @@ -48,7 +49,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'You are a helpful assistant to the user\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -80,6 +81,7 @@

unsloth_eos_token = "eos_token"
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
pass

# =========================================== Zephyr
Expand Down Expand Up @@ -116,6 +118,7 @@

zephyr_eos_token = "eos_token"
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
pass

# =========================================== ChatML
Expand Down Expand Up @@ -153,6 +156,7 @@

chatml_eos_token = "<|im_end|>"
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
pass

# =========================================== Mistral-1
Expand Down Expand Up @@ -193,6 +197,7 @@

mistral_eos_token = "eos_token"
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
pass

# =========================================== Llama-2
Expand Down Expand Up @@ -234,6 +239,7 @@

llama_eos_token = "eos_token"
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
pass

# =========================================== Vicuna
Expand All @@ -244,7 +250,7 @@
"{{ messages[0]['content'] + ' ' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
"{{ '{system_message}' + ' ' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -273,6 +279,7 @@

vicuna_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
pass

# =========================================== Vicuna Old
Expand All @@ -283,7 +290,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -315,6 +322,10 @@

vicuna_old_eos_token = "eos_token"
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."

CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
pass

# =========================================== Alpaca multi turn
Expand All @@ -325,7 +336,7 @@
"{{ messages[0]['content'] + '\n\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
"{{ '{system_message}' + '\n\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -362,6 +373,7 @@

alpaca_eos_token = "eos_token"
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
pass

# =========================================== Gemma
Expand All @@ -372,7 +384,7 @@
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
"{% set loop_messages = messages[2:] %}"\
"{% set messages = messages[2:] %}"\
"{% endif %}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
Expand Down Expand Up @@ -407,6 +419,7 @@

gemma_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
pass

# =========================================== Gemma with ChatML instead
Expand Down Expand Up @@ -437,6 +450,7 @@
"<|im_end|>",
)
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
pass

# =========================================== Gemma 2
Expand All @@ -446,12 +460,14 @@
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2

# =========================================== Gemma 2 with ChatML instead
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
pass

# =========================================== Llama-3
Expand Down Expand Up @@ -491,7 +507,12 @@
'''

llama3_template_eos_token = "eos_token"

CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3

CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
pass


Expand Down Expand Up @@ -532,8 +553,13 @@

phi3_template_eos_token = "<|end|>"
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3

CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5

CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
pass

# =========================================== Llama-3.1
Expand Down Expand Up @@ -573,7 +599,7 @@
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- set system_message = "{system_message}" %}
{%- endif %}

{#- System message + builtin tools #}
Expand Down Expand Up @@ -729,7 +755,10 @@

llama31_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates

CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
pass


Expand All @@ -751,7 +780,7 @@
{%- if messages[0][\'role\'] == \'system\' %}
{{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
{%- else %}
{{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}
{{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }}
{%- endif %}\n{%- endif %}\n{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
{{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
Expand Down Expand Up @@ -847,10 +876,53 @@
'''

qwen25_template_eos_token = "eos_token"
qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5

CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5

CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5

CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
pass

def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
system_message_pattern = r"\{system_message\}"

# For predefined templates, check if default system message exists
default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
if default_system_message is None:
if system_message is not None:
logger.warning_once(
f"Unsloth: You tried to change the system message for {type_chat_template}, "
"but it doesn't have a default system message. "
"You need to manually add the system message in your data."
)
return template, system_message
pass

# For custom templates
if type_chat_template is None:
has_placeholder = re.search(system_message_pattern, template) is not None

if has_placeholder:
if system_message is None:
raise ValueError("Unsloth: You need to provide a system message for custom templates.")
new_template = re.sub(system_message_pattern, system_message, template)
return new_template, system_message

return template, system_message
pass

# For predefined templates with default system message
message_to_use = system_message if system_message is not None else default_system_message
new_template = re.sub(system_message_pattern, message_to_use, template)

return new_template, message_to_use
pass


Expand Down Expand Up @@ -886,14 +958,20 @@ def get_chat_template(
old_padding_side = tokenizer.padding_side

same_padding_token = False

type_chat_template = None

if type(chat_template) in (list, tuple,):
# For changing system message later
# Since it's not supported yet, we will raise an error first!
type_chat_template = chat_template[0].lower()
chat_template, stop_word = chat_template
assert(type(chat_template) is str)
assert(type(stop_word) is str)
ollama_modelfile = None

elif type(chat_template) is str:
# For changing system message later
type_chat_template = chat_template.lower()

chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]

Expand Down Expand Up @@ -1052,6 +1130,9 @@ def get_chat_template(
else:
chat_template = new_chat_template
pass

chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)

tokenizer.chat_template = chat_template

# Also fix up other tokens
Expand Down
4 changes: 2 additions & 2 deletions unsloth/tokenizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,10 +586,10 @@ def load_correct_tokenizer(


def _fix_chat_template(chat_template):
endfor = "{% endfor %}"
endfor = "{% endif %}"
where = chat_template.find(endfor)
if where == -1:
endfor = "{%- endfor %}"
endfor = "{%- endif %}"
where = chat_template.find(endfor)
if where == -1:
return chat_template
Expand Down
Loading