Merged
33 changes: 28 additions & 5 deletions src/transformers/tokenization_utils_base.py
@@ -1697,7 +1697,8 @@ def apply_chat_template(
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
**tokenizer_kwargs,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
Comment on lines +1704 to +1705
Member

Is this a breaking change? 👀

Kwargs that were previously captured by "tokenizer_kwargs" will now be captured by "kwargs" instead.
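For illustration, a minimal before/after sketch of the call-site change (not part of the PR; "gpt2" is a placeholder checkpoint and `pad_to_multiple_of` a stand-in tokenizer kwarg):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
chat = [{"role": "user", "content": "Hi"}]

# Before this PR, loose kwargs were captured by **tokenizer_kwargs and
# forwarded to the tokenizer call:
#   tokenizer.apply_chat_template(chat, tokenize=True, pad_to_multiple_of=8)
# After it, the same loose kwarg would reach the Jinja renderer instead;
# tokenizer kwargs now travel in an explicit dict:
ids = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    tokenizer_kwargs={"pad_to_multiple_of": 8},  # forwarded to the tokenizer
)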

) -> Union[str, List[int]]:
"""
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
@@ -1732,7 +1733,8 @@ def apply_chat_template(
- `'jax'`: Return JAX `jnp.ndarray` objects.
return_dict (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
tokenizer_kwargs (`Dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.

Returns:
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
@@ -1743,8 +1745,28 @@ def apply_chat_template(
# Indicates it's a Conversation object
conversation = conversation.messages

# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if chat_template is None:
if tokenizer_kwargs is None:
tokenizer_kwargs = {}

# First, handle the cases when the model has a dict of multiple templates
if isinstance(self.chat_template, dict) or (
self.chat_template is None and isinstance(self.default_chat_template, dict)
):
template_dict = self.chat_template or self.default_chat_template
if chat_template is not None and chat_template in template_dict:
# The user can pass the name of a template to the chat template argument instead of an entire template
chat_template = template_dict[chat_template]
Comment on lines +1759 to +1762
Member

At this point, isn't chat_template the actual chat template that's passed as input?

chat_template (str, *optional*): A Jinja template to use for this conversion.

It's being used as the key in the template_dict, but what if the user had passed a chat_template to be used as a template?

Member

It seems like if self.chat_template is a dict, a passed-in chat_template is always considered a key, even though it could be a template. If neither self.chat_template nor self.default_chat_template is a dict, then chat_template is considered an actual chat template.

Member Author

It's a bit hacky, but chat_template is only used as a key if it exists as a key in the template_dict. If the user passes an actual Jinja template, that almost certainly will not exist as a key, and so chat_template is treated as a template string.
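To make the resolution order concrete, a small sketch of the three call styles against a template dict (toy templates, not from the PR; "gpt2" is a placeholder checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = {
    "default": "{{ 'default:' + messages[0]['content'] }}",
    "brief": "{{ 'brief:' + messages[0]['content'] }}",
}
chat = [{"role": "user", "content": "Hi"}]

# The string matches a key in the dict, so it is used as a template name:
tokenizer.apply_chat_template(chat, chat_template="brief", tokenize=False)      # "brief:Hi"
# The string is not a key, so it is compiled as a literal Jinja template:
tokenizer.apply_chat_template(chat, chat_template="{{ 'x' }}", tokenize=False)  # "x"
# No argument falls back to the "default" entry:
tokenizer.apply_chat_template(chat, tokenize=False)                             # "default:Hi"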

elif chat_template is None and "default" in template_dict:
chat_template = template_dict["default"]
elif chat_template is None:
raise ValueError(
"This model has multiple chat templates with no default specified! Please either pass a chat "
"template or the name of the template you wish to use to the `chat_template` argument. Available "
f"template names are {sorted(template_dict.keys())}."
)
elif chat_template is None:
# These are the cases when the model has a single template
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if self.chat_template is not None:
chat_template = self.chat_template
else:
@@ -1753,8 +1775,9 @@ def apply_chat_template(
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)

template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
rendered = compiled_template.render(
messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs
)

if padding is True:
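Since the render call above now merges `special_tokens_map` with the caller's kwargs, with kwargs winning on collision, a template variable such as `eos_token` can be overridden per call. A minimal sketch, assuming a toy template (not from the PR; "gpt2" is a placeholder checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = "{{ messages[0]['content'] }}{{ eos_token }}"
chat = [{"role": "user", "content": "Hi"}]

# By default the template sees the tokenizer's own special tokens:
tokenizer.apply_chat_template(chat, tokenize=False)                     # "Hi<|endoftext|>"
# A kwarg of the same name takes priority over the special tokens map:
tokenizer.apply_chat_template(chat, tokenize=False, eos_token="<END>")  # "Hi<END>"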
18 changes: 18 additions & 0 deletions tests/test_tokenization_common.py
@@ -1118,6 +1118,24 @@ def test_chat_template(self):
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised

@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
dummy_template_2 = "{{'b'}}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
tokenizer = self.get_tokenizers()[0]
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
output1 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_1, tokenize=False)
output1_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template1", tokenize=False)
self.assertEqual(output1, output1_via_dict)
output2 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_2, tokenize=False)
output2_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template2", tokenize=False)
self.assertEqual(output2, output2_via_dict)

def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers: