diff --git a/unsloth/chat_templates.py b/unsloth/chat_templates.py index b254202c75..da10f7e003 100644 --- a/unsloth/chat_templates.py +++ b/unsloth/chat_templates.py @@ -39,6 +39,7 @@ train_on_responses_only, ) CHAT_TEMPLATES = {} +DEFAULT_SYSTEM_MESSAGE = {} # =========================================== Unsloth # Unsloth efficient template leverages from Zephyr @@ -48,7 +49,7 @@ "{{ messages[0]['content'] + '\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'You are a helpful assistant to the user\n' }}"\ + "{{ '{system_message}' + '\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -80,6 +81,7 @@ unsloth_eos_token = "eos_token" CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,) +DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user" pass # =========================================== Zephyr @@ -116,6 +118,7 @@ zephyr_eos_token = "eos_token" CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,) +DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr pass # =========================================== ChatML @@ -153,6 +156,7 @@ chatml_eos_token = "<|im_end|>" CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML pass # =========================================== Mistral-1 @@ -193,6 +197,7 @@ mistral_eos_token = "eos_token" CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,) +DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral pass # =========================================== Llama-2 @@ -234,6 +239,7 @@ llama_eos_token = "eos_token" CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama pass # =========================================== 
Vicuna @@ -244,7 +250,7 @@ "{{ messages[0]['content'] + ' ' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\ + "{{ '{system_message}' + ' ' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -273,6 +279,7 @@ vicuna_eos_token = "eos_token" CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,) +DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions." pass # =========================================== Vicuna Old @@ -283,7 +290,7 @@ "{{ messages[0]['content'] + '\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\ + "{{ '{system_message}' + '\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -315,6 +322,10 @@ vicuna_old_eos_token = "eos_token" CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,) +DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + +CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"] +DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"] pass # =========================================== Alpaca multi turn @@ -325,7 +336,7 @@ "{{ messages[0]['content'] + '\n\n' }}"\ "{% set loop_messages = messages[1:] %}"\ "{% else %}"\ - "{{ 'Below are some instructions that describe some tasks. 
Write responses that appropriately complete each request.\n\n' }}"\ + "{{ '{system_message}' + '\n\n' }}"\ "{% set loop_messages = messages %}"\ "{% endif %}"\ "{% for message in loop_messages %}"\ @@ -362,6 +373,7 @@ alpaca_eos_token = "eos_token" CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,) +DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request." pass # =========================================== Gemma @@ -372,7 +384,7 @@ "{{ bos_token }}"\ "{% if messages[0]['role'] == 'system' %}"\ "{{'user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '\n'}}"\ - "{% set loop_messages = messages[2:] %}"\ + "{% set messages = messages[2:] %}"\ "{% endif %}"\ "{% for message in messages %}"\ "{% if message['role'] == 'user' %}"\ @@ -407,6 +419,7 @@ gemma_eos_token = "" CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma pass # =========================================== Gemma with ChatML instead @@ -437,6 +450,7 @@ "<|im_end|>", ) CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma pass # =========================================== Gemma 2 @@ -446,12 +460,14 @@ gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n" gemma2_eos_token = "" CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2 # =========================================== Gemma 2 with ChatML instead gemma2_chatml_template = gemma_chatml_template gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n" gemma2_chatml_eos_token = gemma_chatml_eos_token CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, 
gemma2_chatml_eos_token, True, gemma2_chatml_ollama,) +DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2 pass # =========================================== Llama-3 @@ -491,7 +507,12 @@ ''' llama3_template_eos_token = "eos_token" + CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3 + +CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3 pass @@ -532,8 +553,13 @@ phi3_template_eos_token = "<|end|>" CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,) +DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3 + CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"] +DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5 + CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"] +DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5 pass # =========================================== Llama-3.1 @@ -573,7 +599,7 @@ {%- set system_message = messages[0]['content'] %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "" %} + {%- set system_message = "{system_message}" %} {%- endif %} {#- System message + builtin tools #} @@ -729,7 +755,10 @@ llama31_template_eos_token = "eos_token" CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates + CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,) +DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates pass @@ -751,7 +780,7 @@ {%- if messages[0][\'role\'] == \'system\' %} {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }} {%- else 
%} - {{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }} + {{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }} {%- endif %}\n{%- endif %}\n{%- for message in messages %} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %} {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }} @@ -847,10 +876,53 @@ ''' qwen25_template_eos_token = "eos_token" +qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant." CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5 + CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,) +DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5 +pass + +def _change_system_message(template: str, type_chat_template: str, system_message: str = None): + system_message_pattern = r"\{system_message\}" + + # For predefined templates, check if default system message exists + default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None) + if default_system_message is None: + if system_message is not None: + logger.warning_once( + f"Unsloth: You tried to change the system message for {type_chat_template}, " + "but it doesn't have a default system message. 
"
+                "You need to manually add the system message in your data."
+            )
+        pass
+        # FIX(review): the former "custom template" branch below this block was
+        # unreachable — we returned here unconditionally before it could run,
+        # so a user-supplied template containing {system_message} was never
+        # filled in. Honor the placeholder before returning.
+        if re.search(system_message_pattern, template) is not None:
+            if system_message is None:
+                raise ValueError("Unsloth: You need to provide a system message for custom templates.")
+            # str.replace, not re.sub: re.sub interprets backslashes in the
+            # replacement string as escapes (a message containing \' raises
+            # re.error "bad escape").
+            return template.replace("{system_message}", system_message), system_message
+        return template, system_message
+    pass
+
+    # For predefined templates with a default system message: use the
+    # user-provided message if given, else the registered default.
+    message_to_use = system_message if system_message is not None else default_system_message
+    new_template = template.replace("{system_message}", message_to_use)
+
+    return new_template, message_to_use
+pass


@@ -886,14 +958,20 @@ def get_chat_template(
         old_padding_side = tokenizer.padding_side
         same_padding_token = False
-
+    type_chat_template = None
+
     if type(chat_template) in (list, tuple,):
+        # For changing system message later
+        # NOTE(review): keep type_chat_template = None — chat_template[0] is
+        # the raw template string here, not a registered template name, so
+        # '.lower()' of it can never match a DEFAULT_SYSTEM_MESSAGE key.
         chat_template, stop_word = chat_template
         assert(type(chat_template) is str)
         assert(type(stop_word) is str)
         ollama_modelfile = None

     elif type(chat_template) is str:
+        # For changing system message later
+        type_chat_template = chat_template.lower()

         chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]

@@ -1052,6 +1130,9 @@ def get_chat_template(
     else:
         chat_template = new_chat_template
     pass
+
+    chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)
+
     tokenizer.chat_template = chat_template

     # Also fix up other tokens