diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 50b3d4c6a895..0b2492fc7112 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -354,7 +354,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device ) image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + if left_padding: + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + else: + mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 + padding_mask = mask <= new_token_positions[:, -1:].to(target_device) + image_to_overwrite &= padding_mask if image_to_overwrite.sum() != image_features.shape[:-1].numel(): raise ValueError( diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 0fe89676b92d..a9bd8b745a6f 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -339,7 +339,12 @@ def _merge_input_ids_with_visual_features( # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device) image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + if left_padding: + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + else: + mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 + padding_mask = mask <= new_token_positions[:, -1:].to(target_device) + image_to_overwrite &= padding_mask if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): visual_type = "videos" if num_frames == 8 else "images" diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index dd7baa34406f..987ae0ad0c61 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -350,7 +350,12 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device ) image_to_overwrite[batch_indices, text_to_overwrite] = False - image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + if left_padding: + image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) + else: + mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 + padding_mask = mask <= new_token_positions[:, -1:].to(target_device) + image_to_overwrite &= padding_mask if image_to_overwrite.sum() != image_features.shape[:-1].numel(): raise ValueError(