Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions unsloth/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
train_on_responses_only,
)
# Registry of built-in chat templates, keyed by template name. Each value is a
# 4-tuple that get_chat_template unpacks as
# (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES = {}
# Default system message per template name. A string is the text substituted for
# the literal "{system_message}" placeholder when the caller passes no system
# message; None means the template has no default and is left unchanged.
DEFAULT_SYSTEM_MESSAGE = {}

# =========================================== Unsloth
# Unsloth efficient template leverages from Zephyr
Expand All @@ -48,7 +49,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'You are a helpful assistant to the user\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -80,6 +81,7 @@

unsloth_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
# Substituted for the "{system_message}" placeholder in unsloth_template when
# the caller does not supply a system message.
DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
pass

# =========================================== Zephyr
Expand Down Expand Up @@ -116,6 +118,7 @@

zephyr_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
# None: no default system message is registered, so _change_system_message
# returns the Zephyr template unchanged.
DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
pass

# =========================================== ChatML
Expand Down Expand Up @@ -153,6 +156,7 @@

chatml_eos_token = "<|im_end|>"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
# None: no default system message is registered, so _change_system_message
# returns the ChatML template unchanged.
DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
pass

# =========================================== Mistral-1
Expand Down Expand Up @@ -193,6 +197,7 @@

mistral_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
# None: no default system message is registered, so _change_system_message
# returns the Mistral template unchanged.
DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
pass

# =========================================== Llama-2
Expand Down Expand Up @@ -234,6 +239,7 @@

llama_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
# None: no default system message is registered, so _change_system_message
# returns the Llama-2 template unchanged.
DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
pass

# =========================================== Vicuna
Expand All @@ -244,7 +250,7 @@
"{{ messages[0]['content'] + ' ' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
"{{ '{system_message}' + ' ' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -273,6 +279,7 @@

vicuna_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
# This text is spliced into a single-quoted Jinja string literal in
# vicuna_template ('{system_message}'), so the apostrophe has to reach Jinja as
# \' -- a bare ' would terminate the literal early and break the template.
# This matches the escaping used by the "vicuna_old" default.
DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions."
pass

# =========================================== Vicuna Old
Expand All @@ -283,7 +290,7 @@
"{{ messages[0]['content'] + '\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
"{{ '{system_message}' + '\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -315,6 +322,10 @@

vicuna_old_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
# The apostrophe is written as \' because this text is substituted into a
# single-quoted Jinja string literal ('{system_message}') in vicuna_old_template.
DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."

# "vicuna old" (with a space) is an alias sharing the same entries.
CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
pass

# =========================================== Alpaca multi turn
Expand All @@ -325,7 +336,7 @@
"{{ messages[0]['content'] + '\n\n' }}"\
"{% set loop_messages = messages[1:] %}"\
"{% else %}"\
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
"{{ '{system_message}' + '\n\n' }}"\
"{% set loop_messages = messages %}"\
"{% endif %}"\
"{% for message in loop_messages %}"\
Expand Down Expand Up @@ -362,6 +373,7 @@

alpaca_eos_token = "eos_token"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
# Substituted for the "{system_message}" placeholder in alpaca_template when the
# caller does not supply a system message; the template itself appends the
# trailing blank line.
DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
pass

# =========================================== Gemma
Expand All @@ -372,7 +384,7 @@
"{{ bos_token }}"\
"{% if messages[0]['role'] == 'system' %}"\
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
"{% set loop_messages = messages[2:] %}"\
"{% set messages = messages[2:] %}"\
"{% endif %}"\
"{% for message in messages %}"\
"{% if message['role'] == 'user' %}"\
Expand Down Expand Up @@ -407,6 +419,7 @@

gemma_eos_token = "<end_of_turn>"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
# None: no default system message is registered, so _change_system_message
# returns the Gemma template unchanged.
DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
pass

# =========================================== Gemma with ChatML instead
Expand Down Expand Up @@ -437,6 +450,7 @@
"<|im_end|>",
)
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
# None: no default system message is registered for this template.
DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
pass

# =========================================== Gemma 2
Expand All @@ -446,12 +460,14 @@
# Gemma 2 reuses the Gemma Ollama modelfile with one extra num_ctx parameter line.
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
gemma2_eos_token = "<end_of_turn>"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
# None: no default system message is registered for this template.
DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2

# =========================================== Gemma 2 with ChatML instead
# Same template and EOS token as "gemma_chatml"; only the Ollama modelfile
# differs, by the appended num_ctx parameter line.
gemma2_chatml_template = gemma_chatml_template
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
gemma2_chatml_eos_token = gemma_chatml_eos_token
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
# None: no default system message is registered for this template.
DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
pass

# =========================================== Llama-3
Expand Down Expand Up @@ -491,7 +507,12 @@
'''

llama3_template_eos_token = "eos_token"

# Canonical Llama-3 entry. Tuple layout:
# (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["llama-3"] = (
    llama3_template,
    llama3_template_eos_token,
    False,
    llama3_ollama,
)
# None: Llama-3 has no default system message to substitute.
DEFAULT_SYSTEM_MESSAGE["llama-3"] = None

# "llama3" (no hyphen) is an alias with the same template and default.
CHAT_TEMPLATES["llama3"] = CHAT_TEMPLATES["llama-3"]
DEFAULT_SYSTEM_MESSAGE["llama3"] = DEFAULT_SYSTEM_MESSAGE["llama-3"]
pass


Expand Down Expand Up @@ -532,8 +553,13 @@

phi3_template_eos_token = "<|end|>"
# Tuple layout: (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
# None: no default system message is registered for this template.
DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3

# "phi-35" and "phi-3.5" are aliases that reuse the "phi-3" tuple.
CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5

CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
pass

# =========================================== Llama-3.1
Expand Down Expand Up @@ -573,7 +599,7 @@
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- set system_message = "{system_message}" %}
{%- endif %}

{#- System message + builtin tools #}
Expand Down Expand Up @@ -729,7 +755,10 @@

llama31_template_eos_token = "eos_token"

# Canonical Llama-3.1 entry. Tuple layout:
# (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["llama-3.1"] = (
    llama31_template,
    llama31_template_eos_token,
    False,
    llama31_ollama,
)
# The default is an empty string (not None), so the "{system_message}"
# placeholder in llama31_template is still replaced -- with nothing.
DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = ""

# "llama-31" is an alias with the same template and default.
CHAT_TEMPLATES["llama-31"] = CHAT_TEMPLATES["llama-3.1"]
DEFAULT_SYSTEM_MESSAGE["llama-31"] = DEFAULT_SYSTEM_MESSAGE["llama-3.1"]
pass


Expand All @@ -751,7 +780,7 @@
{%- if messages[0][\'role\'] == \'system\' %}
{{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
{%- else %}
{{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}
{{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }}
{%- endif %}\n{%- endif %}\n{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
{{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
Expand Down Expand Up @@ -847,10 +876,53 @@
'''

qwen25_template_eos_token = "eos_token"
# Default system prompt for Qwen 2.5; it replaces the "{system_message}"
# placeholder in qwen25_template when the caller supplies no system message.
qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

# Canonical Qwen 2.5 entry. Tuple layout:
# (chat_template, stop_word, yes_map_eos_token, ollama_modelfile).
CHAT_TEMPLATES["qwen-2.5"] = (
    qwen25_template,
    qwen25_template_eos_token,
    False,
    qwen25_ollama,
)
DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message

# Alternate spellings of the same template; each shares the canonical entry
# and the same default system message.
CHAT_TEMPLATES["qwen-25"] = CHAT_TEMPLATES["qwen-2.5"]
DEFAULT_SYSTEM_MESSAGE["qwen-25"] = DEFAULT_SYSTEM_MESSAGE["qwen-2.5"]

CHAT_TEMPLATES["qwen25"] = CHAT_TEMPLATES["qwen-2.5"]
DEFAULT_SYSTEM_MESSAGE["qwen25"] = DEFAULT_SYSTEM_MESSAGE["qwen-2.5"]

CHAT_TEMPLATES["qwen2.5"] = CHAT_TEMPLATES["qwen-2.5"]
DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = DEFAULT_SYSTEM_MESSAGE["qwen-2.5"]
pass

def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
    """Fill the literal ``{system_message}`` placeholder of a chat template.

    Args:
        template: Chat template text that may contain ``{system_message}``.
        type_chat_template: Name used to look up ``DEFAULT_SYSTEM_MESSAGE``,
            or ``None`` for a custom template that has no registered name.
        system_message: Text to substitute, or ``None`` to fall back to the
            registered default for ``type_chat_template``.

    Returns:
        A ``(template, system_message)`` tuple. When a substitution happens the
        tuple holds the rewritten template and the message that was inserted;
        otherwise it holds the untouched template and ``system_message`` exactly
        as it was passed in.

    Raises:
        ValueError: ``type_chat_template`` is ``None``, the template contains
            the placeholder, and no ``system_message`` was given.
    """
    placeholder = "{system_message}"

    # Custom templates must be handled first: a None name never has a registered
    # default, so checking the registry first would return early and leave this
    # branch unreachable.
    if type_chat_template is None:
        if placeholder not in template:
            return template, system_message
        if system_message is None:
            raise ValueError("Unsloth: You need to provide a system message for custom templates.")
        return template.replace(placeholder, system_message), system_message
    pass

    # Predefined templates without a default system message are left untouched.
    default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
    if default_system_message is None:
        if system_message is not None:
            logger.warning_once(
                f"Unsloth: You tried to change the system message for {type_chat_template}, "
                "but it doesn't have a default system message. "
                "You need to manually add the system message in your data."
            )
        pass
        return template, system_message
    pass

    # Predefined templates with a default system message.
    message_to_use = system_message if system_message is not None else default_system_message
    # Plain string replacement inserts the message verbatim. Passing it to
    # re.sub as the replacement would treat backslashes as regex escapes
    # (raising on e.g. "\d" and silently rewriting "\n" or "\1").
    new_template = template.replace(placeholder, message_to_use)

    return new_template, message_to_use
pass


Expand Down Expand Up @@ -886,14 +958,20 @@ def get_chat_template(
old_padding_side = tokenizer.padding_side

same_padding_token = False

type_chat_template = None

if type(chat_template) in (list, tuple,):
# For changing system message later
# Since it's not supported yet, we will raise an error first!
type_chat_template = chat_template[0].lower()
chat_template, stop_word = chat_template
assert(type(chat_template) is str)
assert(type(stop_word) is str)
ollama_modelfile = None

elif type(chat_template) is str:
# For changing system message later
type_chat_template = chat_template.lower()

chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]

Expand Down Expand Up @@ -1052,6 +1130,9 @@ def get_chat_template(
else:
chat_template = new_chat_template
pass

chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)

tokenizer.chat_template = chat_template

# Also fix up other tokens
Expand Down