Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -133,14 +132,14 @@ def unpad_image(tensor, original_size):

if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]

return unpadded_tensor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,14 @@ def unpad_image(tensor, original_size):

if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]

return unpadded_tensor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.llava_next.modeling_llava_next import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def preprocess(
image,
image_grid_pinpoints,
size=size_tuple,
patch_size=size["height"],
patch_size=size_tuple[0],
resample=resample,
data_format=input_data_format,
input_data_format=input_data_format,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,14 @@ def unpad_image(tensor, original_size):

if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = min(math.ceil(original_height * scale_factor), current_height)
padding, r = divmod(current_height - new_height, 2)
unpadded_tensor = tensor[:, padding : current_height - (padding + r), :]
new_height = int(round(original_height * scale_factor, 7))
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = min(math.ceil(original_width * scale_factor), current_width)
padding, r = divmod(current_width - new_width, 2)
unpadded_tensor = tensor[:, :, padding : current_width - (padding + r)]
new_width = int(round(original_width * scale_factor, 7))
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]

return unpadded_tensor

Expand Down
31 changes: 20 additions & 11 deletions tests/models/llava_next/test_modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
if is_torch_available():
import torch

from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches, unpad_image
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches


if is_vision_available():
Expand Down Expand Up @@ -298,18 +298,27 @@ def test_mismatching_num_image_tokens(self):
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()

# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 24
pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}

# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)

@parameterized.expand(
[
Expand Down
19 changes: 14 additions & 5 deletions tests/models/llava_next/test_processor_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import tempfile
import unittest

import torch

from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextProcessor
from transformers import LlamaTokenizerFast, LlavaNextProcessor
from transformers.testing_utils import (
require_vision,
)
Expand Down Expand Up @@ -52,6 +54,10 @@ def get_tokenizer(self, **kwargs):
def get_image_processor(self, **kwargs):
return LlavaNextProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)

@staticmethod
def prepare_processor_dict():
return {
Expand All @@ -73,13 +79,16 @@ def test_chat_template_is_saved(self):
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

def test_image_token_filling(self):
processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
# Important to check with non square image
image = torch.randint(0, 2, (3, 500, 316))
expected_image_tokens = 1526
image_token_index = 32000
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id

messages = [
{
Expand Down
31 changes: 19 additions & 12 deletions tests/models/llava_next_video/test_modeling_llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
if is_torch_available():
import torch

from transformers.models.llava_next_video.modeling_llava_next_video import unpad_image


if is_vision_available():
from PIL import Image
Expand Down Expand Up @@ -314,18 +312,27 @@ def test_mismatching_num_image_tokens(self):
image_sizes = torch.cat([image_sizes, image_sizes], dim=0)
_ = model(input_ids=input_ids, pixel_values=pixel_values, image_sizes=image_sizes)

def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()

# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 24
pixel_values = floats_tensor([1, 5, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}

# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)

@parameterized.expand(
[
Expand Down
37 changes: 34 additions & 3 deletions tests/models/llava_next_video/test_processor_llava_next_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import tempfile
import unittest

import torch

from transformers import AutoProcessor, LlamaTokenizerFast, LlavaNextVideoProcessor
from transformers.testing_utils import require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
Expand Down Expand Up @@ -63,6 +65,10 @@ def get_image_processor(self, **kwargs):
def get_video_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)

@classmethod
def prepare_processor_dict(cls):
return {
Expand All @@ -84,6 +90,31 @@ def test_chat_template_is_saved(self):
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
# Important to check with non square image
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id

messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)
31 changes: 19 additions & 12 deletions tests/models/llava_onevision/test_modeling_llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
if is_torch_available():
import torch

from transformers.models.llava_onevision.modeling_llava_onevision import unpad_image


if is_vision_available():
from PIL import Image
Expand Down Expand Up @@ -268,18 +266,27 @@ def test_inputs_embeds_matches_input_ids(self):
out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
torch.testing.assert_close(out_embeds, out_ids)

def test_unpad_image(self):
original_size = (400, 400)
def test_odd_sized_image(self):
# prepare model configuration
config = self.model_tester.get_config()

# Test case width is padded
pixel_values = floats_tensor([3, 400, 601])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# prepare input
num_image_tokens = 10
pixel_values = floats_tensor([1, 2, 3, config.vision_config.image_size, config.vision_config.image_size])
input_ids = ids_tensor([1, 64], config.text_config.vocab_size - 2) + 2
input_ids[:, :num_image_tokens] = config.image_token_index
attention_mask = torch.ones(input_ids.shape, dtype=torch.long).to(torch_device)
inputs_dict = {
"pixel_values": pixel_values,
"image_sizes": torch.tensor([[13, 16]]), # odd-sized image
"input_ids": input_ids,
"attention_mask": attention_mask,
}

# Test case height is padded
pixel_values = floats_tensor([3, 503, 400])
unpadded_tensor = unpad_image(pixel_values, original_size)
self.assertEqual(unpadded_tensor.shape[1:], original_size)
# forward with odd-sized image input
for model_class in self.all_model_classes:
model = model_class(config).to(torch_device)
model(**inputs_dict)

@parameterized.expand(
[
Expand Down
33 changes: 33 additions & 0 deletions tests/models/llava_onevision/test_processor_llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import tempfile
import unittest

import torch

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

Expand Down Expand Up @@ -90,3 +93,33 @@ def test_chat_template_is_saved(self):
# so we check if the same template is loaded
processor_dict = self.prepare_processor_dict()
self.assertTrue(processor_loaded.chat_template == processor_dict.get("chat_template", None))

def test_image_token_filling(self):
processor = self.processor_class.from_pretrained(self.tmpdirname)
processor.patch_size = 14
processor.vision_feature_select_strategy = "default"
processor.image_processor.crop_size = {"height": 336, "width": 336}
processor.image_processor.size = {"shortest_edge": 336}
processor.image_processor.image_grid_pinpoints = [[672, 336]]
processor.num_image_tokens = (processor.image_processor.size["shortest_edge"] // processor.patch_size) ** 2
# Important to check with non square image
image = torch.randint(0, 2, (3, 503, 316))
expected_image_tokens = 1525
image_token_index = processor.image_token_id

messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
inputs = processor(
text=[processor.apply_chat_template(messages)],
images=[image],
return_tensors="pt",
)
image_tokens = (inputs["input_ids"] == image_token_index).sum().item()
self.assertEqual(expected_image_tokens, image_tokens)