Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/transformers/models/fuyu/modeling_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def forward(
if image_patches is not None:
patch_embeddings = self.get_image_features(image_patches)
patch_embeddings = torch.cat(patch_embeddings, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
special_image_mask = self.get_placeholder_tokens(
special_image_mask = self.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=patch_embeddings
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, patch_embeddings)
Expand Down Expand Up @@ -379,6 +379,7 @@ def prepare_inputs_for_generation(
inputs_embeds=None,
image_patches=None,
image_patches_indices=None,
cache_position=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
Expand All @@ -390,10 +391,12 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
image_patches=image_patches,
image_patches_indices=image_patches_indices,
cache_position=cache_position,
**kwargs,
)

if past_key_values is not None:
if cache_position[0] != 0:
# set image_patches and image_patches_indices to `None` for decoding stage
model_inputs["image_patches_indices"] = None
model_inputs["image_patches"] = None

Expand Down