From 1e908c7cebfb4529c3c17d72e5d1bc5fbcd908b7 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 14:28:01 +0000 Subject: [PATCH 01/11] Allow apply_chat_template to pass kwargs to the template --- src/transformers/tokenization_utils_base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 7e5c21a1bdfb..2357a5db5703 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1697,7 +1697,7 @@ def apply_chat_template( max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_dict: bool = False, - **tokenizer_kwargs, + **kwargs, ) -> Union[str, List[int]]: """ Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token @@ -1754,7 +1754,7 @@ def apply_chat_template( compiled_template = self._compile_jinja_template(chat_template) rendered = compiled_template.render( - messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map + messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map, **kwargs ) if padding is True: @@ -1768,7 +1768,7 @@ def apply_chat_template( max_length=max_length, add_special_tokens=False, return_tensors=return_tensors, - **tokenizer_kwargs, + **kwargs, ) else: return self.encode( @@ -1778,7 +1778,7 @@ def apply_chat_template( max_length=max_length, add_special_tokens=False, return_tensors=return_tensors, - **tokenizer_kwargs, + **kwargs, ) else: return rendered From 024f786beb555f986c2529e4143db535d90efe93 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 14:38:11 +0000 Subject: [PATCH 02/11] Fix priority for template_kwargs --- src/transformers/tokenization_utils_base.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2357a5db5703..703c694de2ea 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1697,6 +1697,7 @@ def apply_chat_template( max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_dict: bool = False, + tokenizer_kwargs: Dict[str, Any] = None, **kwargs, ) -> Union[str, List[int]]: """ @@ -1743,6 +1744,9 @@ def apply_chat_template( # Indicates it's a Conversation object conversation = conversation.messages + if tokenizer_kwargs is None: + tokenizer_kwargs = dict() + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template` if chat_template is None: if self.chat_template is not None: @@ -1753,8 +1757,9 @@ def apply_chat_template( # Compilation function uses a cache to avoid recompiling the same template compiled_template = self._compile_jinja_template(chat_template) + template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present rendered = compiled_template.render( - messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map, **kwargs + messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs ) if padding is True: @@ -1768,7 +1773,7 @@ def apply_chat_template( max_length=max_length, add_special_tokens=False, return_tensors=return_tensors, - **kwargs, + **tokenizer_kwargs, ) else: return self.encode( @@ -1778,7 +1783,7 @@ def apply_chat_template( max_length=max_length, add_special_tokens=False, return_tensors=return_tensors, - **kwargs, + **tokenizer_kwargs, ) else: return rendered From 022207767011a303e6b34ddeb02f773da8ae6e18 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 14:41:09 +0000 Subject: [PATCH 03/11] Fix docstring --- src/transformers/tokenization_utils_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 703c694de2ea..6d2f1b58fec7 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1697,7 +1697,7 @@ def apply_chat_template( max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, return_dict: bool = False, - tokenizer_kwargs: Dict[str, Any] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Union[str, List[int]]: """ @@ -1733,7 +1733,8 @@ def apply_chat_template( - `'jax'`: Return JAX `jnp.ndarray` objects. return_dict (`bool`, *optional*, defaults to `False`): Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. - **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. + tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer. + **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template. Returns: `List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This From 2cb238cd4a695608feb0f1f6745201f34df83ddb Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 14:41:34 +0000 Subject: [PATCH 04/11] style fix --- src/transformers/tokenization_utils_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 6d2f1b58fec7..225657d1930f 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1746,7 +1746,7 @@ def apply_chat_template( conversation = conversation.messages if tokenizer_kwargs is None: - tokenizer_kwargs = dict() + tokenizer_kwargs = {} # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template` if chat_template is None: From 1df5db7a8ea33e7f22e48270d0c7bad3bfdba0d6 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 15:01:34 +0000 Subject: [PATCH 05/11] Add the option for the model to have a dict of templates --- src/transformers/tokenization_utils_base.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 225657d1930f..b9b4869a9c79 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1748,8 +1748,25 @@ def apply_chat_template( if tokenizer_kwargs is None: tokenizer_kwargs = {} - # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template` - if chat_template is None: + # First, handle the cases when the model has a dict of multiple templates + if isinstance(self.chat_template, dict) or ( + self.chat_template is None and isinstance(self.default_chat_template, dict) + ): + template_dict = self.chat_template or self.default_chat_template + if chat_template is not None and chat_template in template_dict: + # The user can pass the name of a template to the chat template argument instead of an entire template + chat_template = template_dict[chat_template] + elif chat_template is None and "default" in template_dict: + chat_template = template_dict["default"] + elif chat_template is None: + raise ValueError( + "This model has multiple chat templates with no default specified! Please pass " + "the name of the template you wish to use to the `chat_template` argument. Available " + f"templates are {list(template_dict.keys())}." + ) + elif chat_template is None: + # These are the cases when the model has a single template + # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template if self.chat_template is not None: chat_template = self.chat_template else: From 6a39d6c71ccce15a7938abe40e2fded3271057c5 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 15:07:54 +0000 Subject: [PATCH 06/11] Error message cleanup --- src/transformers/tokenization_utils_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index b9b4869a9c79..f21c1fe3dac3 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1760,9 +1760,9 @@ def apply_chat_template( chat_template = template_dict["default"] elif chat_template is None: raise ValueError( - "This model has multiple chat templates with no default specified! Please pass " - "the name of the template you wish to use to the `chat_template` argument. Available " - f"templates are {list(template_dict.keys())}." + "This model has multiple chat templates with no default specified! Please either pass a chat " + "template or the name of the template you wish to use to the `chat_template` argument. Available " + f"template names are {sorted(template_dict.keys())}." ) elif chat_template is None: # These are the cases when the model has a single template From 8d9789c54a33b4df9d2f6accfe0e091703bd3ebe Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 15:15:21 +0000 Subject: [PATCH 07/11] Add test for chat template dicts --- tests/test_tokenization_common.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 6c900fa72cd4..25357b0ab509 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1118,6 +1118,24 @@ def test_chat_template(self): self.assertEqual(output, expected_output) # Test output is the same after reloading tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised + @require_jinja + def test_chat_template_dict(self): + dummy_template_1 = "{{'a'}}" + dummy_template_2 = "{{'b'}}" + dummy_conversation = [ + {"role": "system", "content": "system message"}, + {"role": "user", "content": "user message"}, + {"role": "assistant", "content": "assistant message"}, + ] + tokenizer = self.get_tokenizers()[0] + tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + output1 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_1, tokenize=False) + output1_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template1", tokenize=False) + self.assertEqual(output1, output1_via_dict) + output2 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_2, tokenize=False) + output2_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template2", tokenize=False) + self.assertEqual(output2, output2_via_dict) + def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: From 53ca0fb2f69ea3db02923366e76ece65f21cfd94 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 15:39:52 +0000 Subject: [PATCH 08/11] Simplify the chat template dict test and apply it to all tokenizers in self.get_tokenizers() --- tests/test_tokenization_common.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 25357b0ab509..6158f305c2a2 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1123,18 +1123,26 @@ def test_chat_template_dict(self): dummy_template_1 = "{{'a'}}" dummy_template_2 = "{{'b'}}" dummy_conversation = [ - {"role": "system", "content": "system message"}, {"role": "user", "content": "user message"}, - {"role": "assistant", "content": "assistant message"}, ] - tokenizer = self.get_tokenizers()[0] - tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} - output1 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_1, tokenize=False) - output1_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template1", tokenize=False) - self.assertEqual(output1, output1_via_dict) - output2 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_2, tokenize=False) - output2_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template2", tokenize=False) - self.assertEqual(output2, output2_via_dict) + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + output1 = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template_1, tokenize=False + ) + output1_via_dict = tokenizer.apply_chat_template( + dummy_conversation, chat_template="template1", tokenize=False + ) + self.assertEqual(output1, output1_via_dict) + output2 = tokenizer.apply_chat_template( + dummy_conversation, chat_template=dummy_template_2, tokenize=False + ) + output2_via_dict = tokenizer.apply_chat_template( + dummy_conversation, chat_template="template2", tokenize=False + ) + self.assertEqual(output2, output2_via_dict) def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False) From 6b71e6e8f0a8b84c17124c3bf07681f0985a4791 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 17:50:22 +0000 Subject: [PATCH 09/11] Save chat template dicts as lists with fixed key names --- src/transformers/tokenization_utils_base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index f21c1fe3dac3..388312df6d63 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1610,6 +1610,10 @@ def __init__(self, **kwargs): # Stores a Jinja template that formats chat histories into tokenizable strings self.chat_template = kwargs.pop("chat_template", None) + if isinstance(self.chat_template, (list, tuple)): + # Chat templates are stored as lists of dicts with fixed key names, + # we reconstruct that into a single dict while loading them. + self.chat_template = {template["name"]: template["template"] for template in self.chat_template} super().__init__(**kwargs) @@ -2449,7 +2453,12 @@ def save_pretrained( tokenizer_config.update(self.special_tokens_map) if self.chat_template is not None: - tokenizer_config["chat_template"] = self.chat_template + if isinstance(self.chat_template, dict): + # Chat template dicts are saved to the config as lists of dicts with fixed key names. + # They will be reconstructed as a single dict during loading. + tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()] + else: + tokenizer_config["chat_template"] = self.chat_template if len(self.init_inputs) > 0: tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs) From e814006a3bb036a385c966f8f4994dca82ae279f Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 17:56:26 +0000 Subject: [PATCH 10/11] Add test for serialization/reloading --- tests/test_tokenization_common.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 6158f305c2a2..dde0dc181c5e 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1144,6 +1144,25 @@ def test_chat_template_dict(self): ) self.assertEqual(output2, output2_via_dict) + def test_chat_template_dict_saving(self): + dummy_template_1 = "{{'a'}}" + dummy_template_2 = "{{'b'}}" + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2} + with tempfile.TemporaryDirectory() as tmp_dir_name: + tokenizer.save_pretrained(tmp_dir_name) + config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json"))) + # Assert that chat templates are correctly serialized as lists of dictionaries + self.assertEqual( + config_dict["chat_template"], + [{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}], + ) + new_tokenizer = tokenizer.from_pretrained(tmp_dir_name) + # Assert that the serialized list is correctly reconstructed as a single dict + self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template) + def test_number_of_added_tokens(self): tokenizers = self.get_tokenizers(do_lower_case=False) for tokenizer in tokenizers: From 97442a6b9dbf4a341b2dcf0e4ec58058663b9bd8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 14 Mar 2024 17:57:51 +0000 Subject: [PATCH 11/11] Add require_jinja just to be safe, even though I don't think we use it --- tests/test_tokenization_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index dde0dc181c5e..bc169332d3fc 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1144,6 +1144,7 @@ def test_chat_template_dict(self): ) self.assertEqual(output2, output2_via_dict) + @require_jinja def test_chat_template_dict_saving(self): dummy_template_1 = "{{'a'}}" dummy_template_2 = "{{'b'}}"