From 8f1bd73ca5f90de9fee89df76207691fa122ba7a Mon Sep 17 00:00:00 2001
From: raushan
Date: Mon, 21 Oct 2024 16:53:34 +0200
Subject: [PATCH 1/2] fix right pad llavas

---
 src/transformers/models/llava/modeling_llava.py             | 6 +++++-
 src/transformers/models/video_llava/modeling_video_llava.py | 6 +++++-
 src/transformers/models/vipllava/modeling_vipllava.py       | 6 +++++-
 3 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index e793ca61c750..5a4e52d5c905 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -349,7 +349,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
         )
         image_to_overwrite[batch_indices, text_to_overwrite] = False
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        if left_padding:
+            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        else:
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
+            image_to_overwrite &= mask.to(target_device)
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
             raise ValueError(
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 5711433c368d..2c7b31189057 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -348,7 +348,11 @@ def _merge_input_ids_with_visual_features(
         # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
         image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device)
         image_to_overwrite[batch_indices, text_to_overwrite] = False
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        if left_padding:
+            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        else:
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
+            image_to_overwrite &= mask.to(target_device)
 
         if image_to_overwrite.sum() != visual_features.shape[:-1].numel():
             visual_type = "videos" if num_frames == 8 else "images"
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 26d92b9ac3dc..d269c03fe776 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -348,7 +348,11 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
         )
         image_to_overwrite[batch_indices, text_to_overwrite] = False
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        if left_padding:
+            image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+        else:
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
+            image_to_overwrite &= mask.to(target_device)
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
             raise ValueError(

From e85e4670057cd3231542c6619f003f415df84d35 Mon Sep 17 00:00:00 2001
From: raushan
Date: Wed, 23 Oct 2024 08:43:15 +0200
Subject: [PATCH 2/2] device mismatch

---
 src/transformers/models/llava/modeling_llava.py             | 5 +++--
 src/transformers/models/video_llava/modeling_video_llava.py | 5 +++--
 src/transformers/models/vipllava/modeling_vipllava.py       | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 5a4e52d5c905..9e48a6316c84 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -352,8 +352,9 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         if left_padding:
             image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
         else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
-            image_to_overwrite &= mask.to(target_device)
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
+            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
+            image_to_overwrite &= padding_mask
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
             raise ValueError(
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 2c7b31189057..0f9651d735e0 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -351,8 +351,9 @@ def _merge_input_ids_with_visual_features(
         if left_padding:
             image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
         else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
-            image_to_overwrite &= mask.to(target_device)
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
+            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
+            image_to_overwrite &= padding_mask
 
         if image_to_overwrite.sum() != visual_features.shape[:-1].numel():
             visual_type = "videos" if num_frames == 8 else "images"
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index d269c03fe776..b77716c17e6b 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -351,8 +351,9 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
         if left_padding:
             image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
         else:
-            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:]
-            image_to_overwrite &= mask.to(target_device)
+            mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
+            padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
+            image_to_overwrite &= padding_mask
 
         if image_to_overwrite.sum() != image_features.shape[:-1].numel():
             raise ValueError(
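
For reference, a minimal standalone sketch of the padding logic these patches touch. It is not part of the patch series: it simplifies the real merge code (text positions are not zeroed out here), and `last_real_position` is a hypothetical stand-in for `new_token_positions[:, -1:]` from the diffs.

import torch

batch_size, seq_len = 2, 6
# Sample 0 has 2 pad tokens, sample 1 has none.
nb_image_pad = torch.tensor([2, 0])
# Index of the last real token per sample under right padding;
# stands in for new_token_positions[:, -1:] in the actual models.
last_real_position = (seq_len - 1) - nb_image_pad

image_to_overwrite = torch.full((batch_size, seq_len), True, dtype=torch.bool)

left_padding = False
if left_padding:
    # Left padding: pads sit at the start of each row, so skip the first
    # nb_image_pad overwritable slots (the pre-existing branch).
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
else:
    # Right padding: pads sit at the end, so keep only positions up to the
    # last real token (the branch the patches introduce).
    positions = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
    image_to_overwrite &= positions <= last_real_position[:, None]

print(image_to_overwrite)
# tensor([[ True,  True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True,  True]])

The second patch keeps this computation but moves `new_token_positions` to `target_device` before the comparison, so both operands of `<=` live on the same device when model parts are spread across devices.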