-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
manage jinja templates as nicely formatted files #2795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
77b729b
b7ac128
d5899fa
63f3186
384f625
e67b133
3d67dfc
bce87fa
5f1b5f8
a06077b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| """ | ||
| This module provides functionality for selecting chat templates based on user choices. | ||
| These templates are used for formatting messages in a conversation. | ||
| """ | ||
|
|
||
| from .base import ( | ||
| _CHAT_TEMPLATES, | ||
| extract_chat_template_args, | ||
| get_chat_template, | ||
| get_chat_template_from_config, | ||
| register_chat_template, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "get_chat_template", | ||
| "extract_chat_template_args", | ||
| "get_chat_template_from_config", | ||
| "register_chat_template", | ||
| "_CHAT_TEMPLATES", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| """ | ||
| utility functions for chat templates | ||
| """ | ||
|
|
||
| import os | ||
| from typing import TYPE_CHECKING, Any, Dict, Optional | ||
|
|
||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers import PreTrainedTokenizerBase | ||
|
|
||
| LOG = get_logger("axolotl.utils.chat_templates") | ||
|
|
||
# `chat_template` choice meaning "use the template supplied via `jinja_template`".
_JINJA_TEMPLATE_CHOICE = "jinja"
# `chat_template` choice meaning "use the template already set on the tokenizer".
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
# Choices of the form `<prefix><name>` use the tokenizer's template when it has one
# and otherwise fall back to the registered template called `<name>`.
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"

# Directory next to this module that holds the bundled `<name>.jinja` template files.
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")

# Registry mapping a template name (the file name without its `.jinja` suffix) to the
# template text. It is filled once, when this module is first imported.
_CHAT_TEMPLATES: dict[str, str] = {}
# `os.listdir` returns entries in arbitrary order; sorting keeps the registry's
# insertion order stable across platforms.
for filename in sorted(os.listdir(TEMPLATE_DIR)):
    if not filename.endswith(".jinja"):
        continue
    with open(
        os.path.join(TEMPLATE_DIR, filename), "r", encoding="utf-8"
    ) as template_file:
        # Strip the suffix by name rather than by a hard-coded length.
        _CHAT_TEMPLATES[filename.removesuffix(".jinja")] = template_file.read()
|
|
||
|
|
||
def get_chat_template(
    user_choice: str,
    jinja_template: str | None = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
    """
    Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.

    The choice is resolved in this order:
        1. `_JINJA_TEMPLATE_CHOICE`: return `jinja_template`, reading it from disk first
           when it is the path of an existing file.
        2. `_DEFAULT_TEMPLATE_CHOICE`: return the tokenizer's own chat template.
        3. A choice starting with `_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX`:
           return the tokenizer's chat template when it has one, otherwise the
           registered template named by the remainder of the choice.
        4. Anything else is looked up by name in the registered templates.

    Args:
        user_choice (str): The user's choice of template.
        jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None.
        tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.

    Returns:
        str: The chosen template string.

    Raises:
        ValueError: If `jinja_template` is missing for the jinja choice, if `tokenizer`
            is missing (or has no chat template) for a choice that requires it, or if
            the resolved template name is not registered.
    """
    if user_choice == _JINJA_TEMPLATE_CHOICE:
        if not jinja_template:
            raise ValueError(
                f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
            )
        # `jinja_template` is either the template text itself or a path to a file that
        # contains it. `isfile` is already False for paths that do not exist.
        if os.path.isfile(jinja_template):
            with open(jinja_template, "r", encoding="utf-8") as file:
                return file.read()
        return jinja_template

    if user_choice == _DEFAULT_TEMPLATE_CHOICE:
        if not tokenizer:
            raise ValueError(
                f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
            )
        if not tokenizer.chat_template:
            raise ValueError(
                f"`chat_template` choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
                f"Please add a chat_template in tokenizer config"
            )
        return tokenizer.chat_template  # type: ignore

    # Name to look up in the registry. Kept separate from `user_choice` so the value
    # the caller passed in is never overwritten.
    template_name = user_choice

    if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
        if not tokenizer:
            raise ValueError(
                f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
            )
        if tokenizer.chat_template:
            return tokenizer.chat_template  # type: ignore

        # The tokenizer has no template of its own: what follows the prefix is the
        # name of the registered template to fall back to.
        template_name = user_choice.removeprefix(
            _DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX
        )
        LOG.warning(
            f"No chat template found on tokenizer, falling back to {template_name}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
        )

    if template_name in _CHAT_TEMPLATES:
        return _CHAT_TEMPLATES[template_name]

    raise ValueError(f"Template '{template_name}' not found.")
|
|
||
|
|
||
| def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None): | ||
| if ds_cfg and ds_cfg.get("chat_template"): | ||
| chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE | ||
| chat_template_jinja = ds_cfg.get("chat_template_jinja") | ||
| else: | ||
| chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE | ||
| chat_template_jinja = cfg.get("chat_template_jinja") | ||
| return chat_template_choice, chat_template_jinja | ||
|
|
||
|
|
||
def get_chat_template_from_config(
    cfg,
    ds_cfg: Dict[str, Any] | None = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
    """
    Resolves the chat template string for the given configs and tokenizer.

    Extracts the template choice and optional jinja template with
    `extract_chat_template_args`, then passes them on to `get_chat_template`.

    Args:
        cfg: The global config.
        ds_cfg (Dict[str, Any], optional): The dataset-level config. Defaults to None.
        tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.

    Returns:
        str: The template string returned by `get_chat_template`.
    """
    template_args = extract_chat_template_args(cfg=cfg, ds_cfg=ds_cfg)
    choice, jinja = template_args
    return get_chat_template(
        user_choice=choice,
        jinja_template=jinja,
        tokenizer=tokenizer,
    )
|
|
||
|
|
||
def register_chat_template(template_name: str, chat_template: str):
    """
    Adds a new named chat template to the module-level template registry.

    An existing entry is never overwritten.

    Args:
        template_name (str): The name to register the template under.
        chat_template (str): The template string.

    Raises:
        ValueError: If a template called `template_name` is already registered.
    """
    name_is_free = template_name not in _CHAT_TEMPLATES
    if name_is_free:
        _CHAT_TEMPLATES[template_name] = chat_template
        return

    raise ValueError(f"Template '{template_name}' already exists.")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| {{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction: | ||
| ' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response: | ||
| ' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ ' | ||
|
|
||
| ' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' | ||
|
|
||
| ### Response: | ||
| ' }}{% endif %} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| {% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + ' | ||
| ' + message['content'] + '<|im_end|>' + ' | ||
| '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant | ||
| ' }}{% endif %} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we run this on global scope? Would this cause any redundant computation whenever a module imports this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
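A module's top-level code is only executed once, on its first import; after that the module is cached, so later imports reuse it without re-running this code.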