-
Notifications
You must be signed in to change notification settings - Fork 31.7k
Allow apply_chat_template to pass kwargs to the template and support a dict of templates #29658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
1e908c7
024f786
0222077
2cb238c
1df5db7
6a39d6c
8d9789c
53ca0fb
6b71e6e
e814006
97442a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1697,7 +1697,8 @@ def apply_chat_template( | |
| max_length: Optional[int] = None, | ||
| return_tensors: Optional[Union[str, TensorType]] = None, | ||
| return_dict: bool = False, | ||
| **tokenizer_kwargs, | ||
| tokenizer_kwargs: Optional[Dict[str, Any]] = None, | ||
| **kwargs, | ||
| ) -> Union[str, List[int]]: | ||
| """ | ||
| Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token | ||
|
|
@@ -1732,7 +1733,8 @@ def apply_chat_template( | |
| - `'jax'`: Return JAX `jnp.ndarray` objects. | ||
| return_dict (`bool`, *optional*, defaults to `False`): | ||
| Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. | ||
| **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. | ||
| tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. | ||
| **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. | ||
|
|
||
| Returns: | ||
| `List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This | ||
|
|
@@ -1743,8 +1745,28 @@ def apply_chat_template( | |
| # Indicates it's a Conversation object | ||
| conversation = conversation.messages | ||
|
|
||
| # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template` | ||
| if chat_template is None: | ||
| if tokenizer_kwargs is None: | ||
| tokenizer_kwargs = {} | ||
|
|
||
| # First, handle the cases when the model has a dict of multiple templates | ||
| if isinstance(self.chat_template, dict) or ( | ||
| self.chat_template is None and isinstance(self.default_chat_template, dict) | ||
| ): | ||
| template_dict = self.chat_template or self.default_chat_template | ||
| if chat_template is not None and chat_template in template_dict: | ||
| # The user can pass the name of a template to the chat template argument instead of an entire template | ||
| chat_template = template_dict[chat_template] | ||
|
Comment on lines
+1759
to
+1762
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At this point, isn't
It's being used as the key in the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like if
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a bit hacky, but |
||
| elif chat_template is None and "default" in template_dict: | ||
| chat_template = template_dict["default"] | ||
| elif chat_template is None: | ||
| raise ValueError( | ||
| "This model has multiple chat templates with no default specified! Please either pass a chat " | ||
| "template or the name of the template you wish to use to the `chat_template` argument. Available " | ||
| f"template names are {sorted(template_dict.keys())}." | ||
| ) | ||
| elif chat_template is None: | ||
| # These are the cases when the model has a single template | ||
| # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template | ||
| if self.chat_template is not None: | ||
| chat_template = self.chat_template | ||
| else: | ||
|
|
@@ -1753,8 +1775,9 @@ def apply_chat_template( | |
| # Compilation function uses a cache to avoid recompiling the same template | ||
| compiled_template = self._compile_jinja_template(chat_template) | ||
|
|
||
| template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present | ||
| rendered = compiled_template.render( | ||
| messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map | ||
| messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs | ||
| ) | ||
|
|
||
| if padding is True: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this a breaking change? 👀
Previously passed "tokenizer kwargs" will now be passed to "kwargs"