9 changes: 6 additions & 3 deletions src/transformers/processing_utils.py
@@ -502,9 +502,12 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
         output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)

         processor_dict = self.to_dict()
-        chat_template = processor_dict.pop("chat_template", None)
-        if chat_template is not None:
-            chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
+        # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
+        # to avoid serializing chat template in json config file. So let's get it from `self` directly
+        if self.chat_template is not None:
+            chat_template_json_string = (
+                json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
+            )
             with open(output_chat_template_file, "w", encoding="utf-8") as writer:
                 writer.write(chat_template_json_string)
             logger.info(f"chat template saved in {output_chat_template_file}")
43 changes: 41 additions & 2 deletions tests/models/llava/test_processor_llava.py
@@ -11,18 +11,57 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import tempfile
 import unittest

+from transformers import AutoTokenizer, LlamaTokenizerFast, LlavaProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+

 if is_vision_available():
-    from transformers import AutoTokenizer, LlavaProcessor
+    from transformers import CLIPImageProcessor


 @require_vision
-class LlavaProcessorTest(unittest.TestCase):
+class LlavaProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = CLIPImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaProcessor(image_processor, tokenizer, **processor_kwargs)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return LlavaProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return LlavaProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    def test_processor_to_json_string(self):
+        processor = self.get_processor()
+        obj = json.loads(processor.to_json_string())
+        for key, value in self.prepare_processor_dict().items():
+            # chat templates are popped from dict
+            if key != "chat_template":
+                self.assertEqual(obj[key], value)
+                self.assertEqual(getattr(processor, key, None), value)
Contributor:

Let's also assert the chat template is popped from the dict.

Suggested change:
-        for key, value in self.prepare_processor_dict().items():
-            # chat templates are popped from dict
-            if key != "chat_template":
-                self.assertEqual(obj[key], value)
-                self.assertEqual(getattr(processor, key, None), value)
+        for key, value in self.prepare_processor_dict().items():
+            # chat templates are popped from dict
+            self.assertFalse(key == "chat_template")
+            self.assertEqual(obj[key], value)
+            self.assertEqual(getattr(processor, key, None), value)

Member Author:

Oops, no. The test fails because the processor_dict for llava has a chat_template key, and we use it in other tests, e.g. to init and save the processor. This test is the same as the general one, except that chat templates cannot pass the self.assertEqual(obj[key], value) check.

So we just want to test all other processor kwargs except the chat template, which is tested separately. By other kwargs, I mean the ones which will be added soon.
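To make that concrete, a hypothetical sketch (the extra key below is invented for illustration and is not added by this PR): if prepare_processor_dict later returns more kwargs, for example

    def prepare_processor_dict(self):
        # "patch_size" is an invented example key, not part of this PR
        return {"chat_template": "dummy_template", "patch_size": 14}

then the loop in test_processor_to_json_string would verify that "patch_size" round-trips through to_json_string, while still skipping "chat_template".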

Contributor:

> with the exception that chat templates cannot pass the self.assertEqual(obj[key], value) check

I see, I missed what was happening originally in the test. Isn't self.prepare_processor_dict().items() a bit redundant, given that we force self.prepare_processor_dict() to have only one key, "chat_template", so all of this logic is skipped?

Member Author:

Yes, the same way almost all processors skip this test. For VLMs, this test will become usable when we enforce the new processing logic for input expansion with image tokens. Until then, we can override it to prevent failing tests, instead of unittest.skip.

Contributor:

Hmmm, I don't think overriding to make it look like tests are passing is a great idea. Skipping is far better, as it's easier to spot and track.

Part of the issue here is that this new behaviour still isn't being tested then: we want to make sure that chat_template isn't in the processor_dict when saving out.

Member Author:

Ah, we can add one more assert in test_chat_template_is_saved to check the content of processor_dict.

OK, I'll skip it then, with a comment explaining why and noting that we need to stop skipping it at some point.

Contributor:

Perfect :)
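A sketch of the agreed follow-up, written as methods to add to the test class (hedged: the skip reason wording and the processor_config.json filename are assumptions, not taken from this diff):

    import json
    import os
    import unittest

    @unittest.skip("processor_dict only holds chat_template for now; unskip once more kwargs are serialized")
    def test_processor_to_json_string(self):
        pass

    def test_chat_template_is_saved(self):
        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
        processor_dict = self.prepare_processor_dict()
        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
        # Extra assert: the template must not leak into the saved json config
        # (assumes the processor config is saved as processor_config.json)
        with open(os.path.join(self.tmpdirname, "processor_config.json"), encoding="utf-8") as f:
            self.assertNotIn("chat_template", json.load(f))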


+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))
+
     def test_can_load_various_tokenizers(self):
         for checkpoint in ["Intel/llava-gemma-2b", "llava-hf/llava-1.5-7b-hf"]:
             processor = LlavaProcessor.from_pretrained(checkpoint)
45 changes: 43 additions & 2 deletions tests/models/llava_next/test_processor_llava_next.py
@@ -11,20 +11,61 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+import tempfile
 import unittest

 import torch

+from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
 from transformers.testing_utils import require_vision
 from transformers.utils import is_vision_available

+from ...test_processing_common import ProcessorTesterMixin
+

 if is_vision_available():
-    from transformers import AutoProcessor
+    from transformers import CLIPImageProcessor


 @require_vision
-class LlavaProcessorTest(unittest.TestCase):
+class LlavaNextProcessorTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = LlavaNextProcessor
+
+    def setUp(self):
+        self.tmpdirname = tempfile.mkdtemp()
+
+        image_processor = CLIPImageProcessor()
+        tokenizer = LlamaTokenizerFast.from_pretrained("huggyllama/llama-7b")
+        processor_kwargs = self.prepare_processor_dict()
+        processor = LlavaNextProcessor(image_processor, tokenizer, **processor_kwargs)
+        processor.save_pretrained(self.tmpdirname)
+
+    def get_tokenizer(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    def get_image_processor(self, **kwargs):
+        return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
+
+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_processor_to_json_string
+    def test_processor_to_json_string(self):
+        processor = self.get_processor()
+        obj = json.loads(processor.to_json_string())
+        for key, value in self.prepare_processor_dict().items():
+            # chat templates are popped from dict
+            if key != "chat_template":
+                self.assertEqual(obj[key], value)
+                self.assertEqual(getattr(processor, key, None), value)
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

     def test_chat_template(self):
         processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
         expected_prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"
25 changes: 23 additions & 2 deletions tests/models/llava_onevision/test_processing_llava_onevision.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import shutil
 import tempfile
 import unittest
@@ -40,9 +41,10 @@ def setUp(self):
         image_processor = LlavaOnevisionImageProcessor()
         video_processor = LlavaOnevisionVideoProcessor()
         tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+        processor_kwargs = self.prepare_processor_dict()

         processor = LlavaOnevisionProcessor(
-            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer
+            video_processor=video_processor, image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs
         )
         processor.save_pretrained(self.tmpdirname)
@@ -52,9 +54,28 @@ def get_tokenizer(self, **kwargs):
     def get_image_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

-    def get_Video_processor(self, **kwargs):
+    def get_video_processor(self, **kwargs):
         return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

+    def prepare_processor_dict(self):
+        return {"chat_template": "dummy_template"}
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_processor_to_json_string
+    def test_processor_to_json_string(self):
+        processor = self.get_processor()
+        obj = json.loads(processor.to_json_string())
+        for key, value in self.prepare_processor_dict().items():
+            # chat templates are popped from dict
+            if key != "chat_template":
+                self.assertEqual(obj[key], value)
+                self.assertEqual(getattr(processor, key, None), value)
+
+    # Copied from tests.models.llava.test_processor_llava.LlavaProcessorTest.test_chat_template_is_saved
+    def test_chat_template_is_saved(self):
+        processor_loaded = self.processor_class.from_pretrained(self.tmpdirname)
+        processor_dict = self.prepare_processor_dict()
+        self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

     def tearDown(self):
         shutil.rmtree(self.tmpdirname)
