Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ include requirements.txt
include README.md
include LICENSE
include src/setuptools_axolotl_dynamic_dependencies.py
include src/axolotl/utils/chat_templates/templates/*.jinja
recursive-include axolotl *.py
149 changes: 0 additions & 149 deletions src/axolotl/utils/chat_templates.py

This file was deleted.

20 changes: 20 additions & 0 deletions src/axolotl/utils/chat_templates/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
This module provides functionality for selecting chat templates based on user choices.
These templates are used for formatting messages in a conversation.
"""

from .base import (
_CHAT_TEMPLATES,
extract_chat_template_args,
get_chat_template,
get_chat_template_from_config,
register_chat_template,
)

__all__ = [
"get_chat_template",
"extract_chat_template_args",
"get_chat_template_from_config",
"register_chat_template",
"_CHAT_TEMPLATES",
]
125 changes: 125 additions & 0 deletions src/axolotl/utils/chat_templates/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""
utility functions for chat templates
"""

import os
from typing import TYPE_CHECKING, Any, Dict, Optional

from axolotl.utils.logging import get_logger

if TYPE_CHECKING:
from transformers import PreTrainedTokenizerBase

LOG = get_logger("axolotl.utils.chat_templates")

_JINJA_TEMPLATE_CHOICE = "jinja"
_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default"
_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_"

# Directory that holds the built-in ``<name>.jinja`` template files.
TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates")
_TEMPLATE_SUFFIX = ".jinja"


def _load_builtin_templates(template_dir: str) -> dict[str, str]:
    """
    Read every ``*.jinja`` file in ``template_dir`` into a name -> template mapping.

    The mapping key is the file name with the ``.jinja`` suffix removed and the
    value is the file's full text decoded as UTF-8.

    Args:
        template_dir (str): Directory to scan (not recursive).

    Returns:
        dict[str, str]: Template name mapped to template source.

    Raises:
        OSError: If the directory or one of its template files cannot be read.
    """
    templates: dict[str, str] = {}
    # os.listdir returns entries in arbitrary order; sort so the registry is
    # populated in the same order on every platform.
    for entry in sorted(os.listdir(template_dir)):
        if not entry.endswith(_TEMPLATE_SUFFIX):
            continue
        with open(
            os.path.join(template_dir, entry), "r", encoding="utf-8"
        ) as handle:
            templates[entry.removesuffix(_TEMPLATE_SUFFIX)] = handle.read()
    return templates


# Registry of known chat templates. It is filled once when this module is first
# imported; loading inside a helper keeps loop variables out of the module
# namespace.
_CHAT_TEMPLATES: dict[str, str] = _load_builtin_templates(TEMPLATE_DIR)
Comment on lines +21 to +23

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we run this on global scope? Would this cause any redundant computation whenever a module imports this file?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

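A module's top-level code is only executed once, on first import; after that the module is cached in `sys.modules`, so later imports reuse it and the template files are not read again.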



def get_chat_template(
user_choice: str,
jinja_template: str | None = None,
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
"""
Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer.

Args:
user_choice (str): The user's choice of template.
jinja_template (str, optional): The jinja template string or Path to a valid jinja template file. Defaults to None.
tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.

Returns:
str: The chosen template string.

Raises:
ValueError: If the user_choice is not found in the templates.
"""
if user_choice == _JINJA_TEMPLATE_CHOICE:
if not jinja_template:
raise ValueError(
f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPLATE_CHOICE}"
)
if os.path.exists(jinja_template) and os.path.isfile(jinja_template):
with open(jinja_template, "r", encoding="utf-8") as file:
jinja_template = file.read()
return jinja_template

if user_choice == _DEFAULT_TEMPLATE_CHOICE:
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}"
)
if not tokenizer.chat_template:
raise ValueError(
f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. "
f"Please add a chat_template in tokenizer config"
)
return tokenizer.chat_template # type: ignore

if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX):
if not tokenizer:
raise ValueError(
f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}"
)
if tokenizer.chat_template:
return tokenizer.chat_template # type: ignore

user_choice = user_choice[
len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really sure what this is doing?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

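Most of this was copy-pasted unchanged while migrating the file into the new module. I believe it strips the magic-string prefix (`tokenizer_default_fallback_`) so that the remainder can be used as the name of the fallback template, but I don't recall the details offhand.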

]
LOG.warning(
f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template."
Comment on lines +67 to +79

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Mutating user_choice hampers error messages & logging

Inside the fallback branch you overwrite user_choice after trimming the prefix.
Keep the original value for clearer error reporting:

-        user_choice = user_choice[
-            len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :
-        ]
+        fallback_name = user_choice[len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) :]

…and use fallback_name afterwards.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/utils/chat_templates/base.py around lines 63 to 75, avoid
mutating the variable user_choice by trimming the prefix directly on it, as this
reduces clarity in error messages and logging. Instead, preserve the original
user_choice value and assign the trimmed result to a new variable fallback_name.
Use fallback_name for subsequent logic and logging to maintain clear and
accurate messages referencing the original user_choice.

)

if user_choice in _CHAT_TEMPLATES:
return _CHAT_TEMPLATES[user_choice]

raise ValueError(f"Template '{user_choice}' not found.")


def extract_chat_template_args(cfg, ds_cfg: Dict[str, Any] | None = None):
if ds_cfg and ds_cfg.get("chat_template"):
chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = ds_cfg.get("chat_template_jinja")
else:
chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE
chat_template_jinja = cfg.get("chat_template_jinja")
return chat_template_choice, chat_template_jinja


def get_chat_template_from_config(
    cfg,
    ds_cfg: Dict[str, Any] | None = None,
    tokenizer: Optional["PreTrainedTokenizerBase"] = None,
) -> str:
    """
    Resolve the chat template string described by the given configs.

    Convenience wrapper: extracts the template choice and jinja setting with
    ``extract_chat_template_args`` and passes them, together with ``tokenizer``,
    to ``get_chat_template``.

    Args:
        cfg: Top-level config; must support ``.get(key)``.
        ds_cfg (Dict[str, Any], optional): Dataset-level config. Defaults to None.
        tokenizer (PreTrainedTokenizerBase, optional): The tokenizer. Defaults to None.

    Returns:
        str: The chosen template string.
    """
    choice, jinja_source = extract_chat_template_args(cfg=cfg, ds_cfg=ds_cfg)
    resolved = get_chat_template(
        user_choice=choice,
        jinja_template=jinja_source,
        tokenizer=tokenizer,
    )
    return resolved


def register_chat_template(template_name: str, chat_template: str):
    """
    Add a new named chat template to the module-level registry.

    Existing entries are never overwritten.

    Args:
        template_name (str): The name to register the template under.
        chat_template (str): The template string.

    Raises:
        ValueError: If a template with this name is already registered.
    """
    name_taken = template_name in _CHAT_TEMPLATES
    if name_taken:
        raise ValueError(f"Template '{template_name}' already exists.")

    _CHAT_TEMPLATES.update({template_name: chat_template})
8 changes: 8 additions & 0 deletions src/axolotl/utils/chat_templates/templates/alpaca.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' and loop.first %}{{ message['content'] }}{% elif message['role'] == 'user' %}{{ '### Instruction:
' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ '### Response:
' + message['content'] + eos_token }}{% endif %}{% if not loop.last %}{{ '

' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '

### Response:
' }}{% endif %}
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates/templates/aya.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}
4 changes: 4 additions & 0 deletions src/axolotl/utils/chat_templates/templates/chatml.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}
1 change: 1 addition & 0 deletions src/axolotl/utils/chat_templates/templates/cohere.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}
Loading
Loading