diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index ee28c01189b4..1d86b128d31d 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -502,9 +502,12 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
 
         processor_dict = self.to_dict()
-        chat_template = processor_dict.pop("chat_template", None)
-        if chat_template is not None:
-            chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
+        # Save `chat_template` in its own file. We can't get it from `processor_dict`, as we popped it in `to_dict`
+        # to avoid serializing the chat template into the JSON config file. So we get it from `self` directly.
+        if self.chat_template is not None:
+            chat_template_json_string = (
+                json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
+            )
             with open(output_chat_template_file, "w", encoding="utf-8") as writer:
                 writer.write(chat_template_json_string)
             logger.info(f"chat template saved in {output_chat_template_file}")
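For reviewers, a minimal sketch (not part of the patch) of what the dedicated `chat_template.json` file ends up containing; the template value is illustrative, real templates are Jinja strings:

```python
import json

# Illustrative template value (real chat templates are Jinja strings).
chat_template = "dummy_template"

# Mirrors the serialization in save_pretrained above: a single-key JSON
# object, pretty-printed with sorted keys and a trailing newline.
chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
print(chat_template_json_string, end="")
# {
#   "chat_template": "dummy_template"
# }
```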
diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py
index 5b05a8b92ea5..e62769e34509 100644
--- a/tests/models/llava/test_processor_llava.py
+++ b/tests/models/llava/test_processor_llava.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import shutil
 import tempfile
 import unittest
@@ -32,11 +33,11 @@ class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
 
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()
+
         image_processor = CLIPImageProcessor(do_center_crop=False)
         tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
-
-        processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)
-
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaProcessor(image_processor, tokenizer, **processor_kwargs)
         processor.save_pretrained(self.tmpdirname)
 
     def get_tokenizer(self, **kwargs):
@@ -48,6 +49,28 @@ def get_image_processor(self, **kwargs):
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
 
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and the "
+        "chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved as a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def test_can_load_various_tokenizers(self):
         for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
             processor = LlavaProcessor.from_pretrained(checkpoint)
diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py
index c8b58ce7982f..450034f4151d 100644
--- a/tests/models/llava_next/test_processor_llava_next.py
+++ b/tests/models/llava_next/test_processor_llava_next.py
@@ -11,20 +11,65 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import tempfile
 import unittest
 
 import torch
 
+from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available
 
+from ...test_processing_common import ProcessorTesterMixin
+
 
 if is_vision_available():
-    from transformers import AutoProcessor
+    from transformers import CLIPImageProcessor
 
 
 @require_vision
-class LlavaProcessorTest(unittest.TestCase):
+class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaNextProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = CLIPImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and the "
+        "chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved as a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def test_chat_template(self):
         processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
         expected_prompt = "USER: \nWhat is shown in this image? ASSISTANT:"
diff --git a/tests/models/llava_onevision/test_processing_llava_onevision.py b/tests/models/llava_onevision/test_processing_llava_onevision.py
index e045f2ba7f0b..f747c18250b6 100644
--- a/tests/models/llava_onevision/test_processing_llava_onevision.py
+++ b/tests/models/llava_onevision/test_processing_llava_onevision.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import shutil
 import tempfile
 import unittest
@@ -40,9 +41,10 @@ def setUp(self):
         image_processor = LlavaOnevisionImageProcessor()
         video_processor = LlavaOnevisionVideoProcessor()
         tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        processor_kwargs = self.prepare_processor_dict()
 
         processor = LlavaOnevisionProcessor(
-            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer
+            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
         )
         processor.save_pretrained(self.tmpdirname)
 
@@ -52,9 +54,32 @@ def get_tokenizer(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
 
     def get_image_processor(self, **kwargs):
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
 
-    def get_Video_processor(self, **kwargs):
+    def get_video_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor
 
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    @unittest.skip(
+        "Skip because the model has no processor kwargs except for the chat template, and the "
+        "chat template is saved as a separate file. Stop skipping this test when the processor "
+        "has new kwargs saved in the config file."
+    )
+    def test_processor_to_json_string(self):
+        pass
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict_loaded = json.loads(processor_loaded.to_json_string())
+        # chat templates aren't serialized to json in processors
+        self.assertFalse("chat_template" in processor_dict_loaded.keys())
+
+        # they have to be saved as a separate file and loaded back from that file,
+        # so we check that the same template is loaded
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
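The three new tests above all assert the same round trip. A condensed, self-contained sketch of that behavior (the checkpoint name and template value are illustrative, and `LlavaProcessor` stands in for any processor with a chat template):

```python
import json
import tempfile

from transformers import LlavaProcessor

# Hypothetical round trip mirroring test_chat_template_is_saved:
# save_pretrained() writes chat_template.json next to the processor config,
# and from_pretrained() restores the template onto the processor instance.
with tempfile.TemporaryDirectory() as tmpdir:
    processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    processor.chat_template = "dummy_template"  # illustrative template value
    processor.save_pretrained(tmpdir)

    reloaded = LlavaProcessor.from_pretrained(tmpdir)
    assert reloaded.chat_template == "dummy_template"
    # The processor's own JSON config no longer carries the template.
    assert "chat_template" not in json.loads(reloaded.to_json_string())
```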