8 changes: 5 additions & 3 deletions src/transformers/models/llava/modeling_llava.py
@@ -476,6 +476,7 @@ def forward(
     inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
         image_features, inputs_embeds, input_ids, attention_mask, labels
     )
+    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
 else:
     # Retrieve the first layer to inspect the logits and mask out the hidden states
     # that are set to 0
@@ -506,6 +507,9 @@ def forward(

     attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
+        -target_length:
+    ]
 
 # TODO: @raushan retain only the new behavior after v4.47
 else:
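Both new assignments derive `cache_position` from the expanded attention mask rather than from the original input ids: the prefill branch enumerates every position, while the legacy decoding branch keeps only the last `target_length` positions for the tokens fed in the current step. A minimal sketch of that behavior, using an assumed toy mask and `target_length` rather than real model state:

```python
# Sketch only: how the two cache_position assignments above behave on toy shapes.
import torch

attention_mask = torch.ones(1, 10, dtype=torch.long)  # 10 positions after image-token expansion
target_length = 1                                      # tokens fed in the current decoding step

# Prefill branch: one position per visible token.
cache_position_prefill = torch.arange(attention_mask.shape[1], device=attention_mask.device)

# Legacy decoding branch: keep only the positions of the tokens passed in this step.
cache_position_decode = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
    -target_length:
]

print(cache_position_prefill.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(cache_position_decode.tolist())   # [9]
```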
@@ -585,9 +589,7 @@ def prepare_inputs_for_generation(
         **kwargs,
     )
 
-    if legacy_processing:
-        model_inputs["pixel_values"] = pixel_values
-    elif cache_position[0] == 0:
+    if legacy_processing or cache_position[0] == 0:
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model
         model_inputs["pixel_values"] = pixel_values
2 changes: 1 addition & 1 deletion src/transformers/models/llava/processing_llava.py
@@ -136,6 +136,7 @@ def __call__(
     raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
 # try to expand inputs in processing if we have the necessary parts
+prompt_strings = text
 if image_inputs.get("pixel_values") is not None:
     if self.patch_size is not None and self.vision_feature_select_strategy is not None:
         # Replace the image token with the expanded image token sequence
@@ -150,7 +151,6 @@ def __call__(
             sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
             prompt_strings.append(sample)
     else:
-        prompt_strings = text
         logger.warning_once(
             "Expanding inputs for image tokens in LLaVa should be done in processing. "
             "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
13 changes: 7 additions & 6 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -848,6 +848,7 @@ def forward(
         position_ids,
         labels=labels,
     )
+    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
 else:
     # Retrieve the first layer to inspect the logits and mask out the hidden states
     # that are set to 0
@@ -877,6 +878,9 @@ def forward(
     extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
     attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
     position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
+        -target_length:
+    ]
 
 # TODO: @raushan retain only the new behavior after v4.47
 else:
@@ -956,12 +960,9 @@ def prepare_inputs_for_generation(
         **kwargs,
     )
 
-    if legacy_processing:
-        model_inputs["pixel_values"] = pixel_values
-        model_inputs["image_sizes"] = image_sizes
-    elif cache_position[0] == 0:
-        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-        # Otherwise we need pixel values to be passed to model
+    # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+    # Otherwise we need pixel values to be passed to model
+    if legacy_processing or cache_position[0] == 0:
         model_inputs["pixel_values"] = pixel_values
         model_inputs["image_sizes"] = image_sizes
 
47 changes: 23 additions & 24 deletions src/transformers/models/llava_next/processing_llava_next.py
@@ -140,30 +140,29 @@ def __call__(
 elif not isinstance(text, list) and not isinstance(text[0], str):
     raise ValueError("Invalid input text. Please provide a string, or a list of strings")
 
-if self.patch_size is None or self.vision_feature_select_strategy is None:
-    prompt_strings = text
-    logger.warning_once(
-        "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
-        "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
-        "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
-        "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
-    )
-# cannot infer image expansion length if no images are found
-elif not image_inputs:
-    prompt_strings = text
-else:
-    image_sizes = image_inputs["image_sizes"]
-    height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
-    prompt_strings = []
-    for image_size, sample in zip(image_sizes, text):
-        # Replace the image token with the expanded image token sequence
-        orig_height, orig_width = image_size
-        num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
-        if self.vision_feature_select_strategy == "default":
-            num_image_tokens -= 1
-
-        sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
-        prompt_strings.append(sample)
+prompt_strings = text
+if image_inputs:
+    if self.patch_size is None or self.vision_feature_select_strategy is None:
+        logger.warning_once(
+            "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
+            "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
+            "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
+            "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
Comment on lines +143 to +150

Member Author:
This was needed for cases with multi-image inputs, where we cannot be sure that the number of image sizes matches the number of text prompts (for example, one text and two images).
+        )
+    else:
+        image_sizes = iter(image_inputs["image_sizes"])
+        height, width = get_image_size(to_numpy_array(image_inputs["pixel_values"][0][0]))
+        prompt_strings = []
+        for sample in text:
+            while self.image_token in sample:
+                image_size = next(image_sizes)
Contributor:
Is this right? The previous logic implies len(image_sizes) == len(text), whereas this implies we have the same number of image_sizes as the number of image tokens per sample times the number of samples.

Member Author:
Yes, I verified twice to be sure. The number of images == the number of image sizes. The newly added multiimage_expansion test fails with the previous logic.

Member Author:
@amyeroberts do you have any other concerns regarding this PR?
+                orig_height, orig_width = image_size
+                num_image_tokens = self._get_number_of_features(orig_height, orig_width, height, width)
+                if self.vision_feature_select_strategy == "default":
+                    num_image_tokens -= 1
+                sample = sample.replace(self.image_token, "<placeholder>" * num_image_tokens, 1)
+            prompt_strings.append(sample)
+        prompt_strings = [sample.replace("<placeholder>", self.image_token) for sample in prompt_strings]
 
 text_inputs = self.tokenizer(
     prompt_strings,
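In the new loop, every occurrence of the image token consumes the next entry of `image_sizes`, so the expansion works per image rather than per prompt, which is what the review discussion above is about. The temporary `"<placeholder>"` string keeps the freshly inserted tokens from being matched again by the `while self.image_token in sample` condition. A minimal, self-contained sketch of the same pattern, with a made-up per-image token count standing in for `_get_number_of_features()`:

```python
# Sketch only: multi-image expansion with one image_sizes entry consumed per image token.
image_token = "<image>"
image_sizes = iter([(672, 672), (336, 336)])        # one entry per image, not per prompt
fake_num_tokens = {(672, 672): 6, (336, 336): 3}    # hypothetical counts for this sketch

text = ["First <image> then <image> in one prompt."]
prompt_strings = []
for sample in text:
    while image_token in sample:
        size = next(image_sizes)                     # consume one image size per image token
        n = fake_num_tokens[size]
        # Insert a placeholder so the expanded tokens are not re-matched by the while loop.
        sample = sample.replace(image_token, "<placeholder>" * n, 1)
    prompt_strings.append(sample)
prompt_strings = [s.replace("<placeholder>", image_token) for s in prompt_strings]
print(prompt_strings[0].count(image_token))  # 9 == 6 + 3
```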