diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index fd76c0b11522..20dbfc317e13 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1287,6 +1287,12 @@ def forward(
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
+            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
+            n_image_features = image_tokens.shape[0]
+            if n_image_tokens_in_text != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
+                )
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index e793ca61c750..411b96f5c57a 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -518,6 +518,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                n_image_features = image_features.shape[1]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 705821c2b713..75dfcf5393ea 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -895,6 +895,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 7df4cf20372b..30257b843978 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -967,6 +967,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -976,6 +982,12 @@ def forward(
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index 4b6be407dcab..e7de66de444a 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -482,6 +482,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -491,6 +497,12 @@ def forward(
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
                 if video_features is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index f65c0fe7cfb3..3eefb517b16d 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -619,7 +619,12 @@ def forward(
                 image_newline=self.image_newline,
                 vision_aspect_ratio=vision_aspect_ratio,
             )
-
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
             special_image_mask = (
                 (input_ids == self.config.image_token_index)
                 .unsqueeze(-1)
@@ -647,7 +652,12 @@ def forward(
                 image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
                 video_features = torch.cat((video_features, image_newline), dim=1)
                 video_features = video_features.flatten(0, 1)
-
+            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+            n_video_features = video_features.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
             special_video_mask = (
                 (input_ids == self.config.video_token_index)
                 .unsqueeze(-1)
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 283e38d3a7d5..e014a6da6bb3 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1710,6 +1710,12 @@ def forward(
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.visual.get_dtype())
                 image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+                n_image_features = image_embeds.shape[0]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 image_mask = (
                     (input_ids == self.config.image_token_id)
                     .unsqueeze(-1)
@@ -1722,6 +1728,12 @@ def forward(
             if pixel_values_videos is not None:
                 pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                 video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+                n_video_features = video_embeds.shape[0]
+                if n_video_tokens != n_video_features:
+                    raise ValueError(
+                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                    )
                 video_mask = (
                     (input_ids == self.config.video_token_id)
                     .unsqueeze(-1)
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 5711433c368d..20fa0166b80c 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -618,6 +618,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_outputs is not None:
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                    n_image_features = image_features.shape[1]
+                    if n_image_tokens != n_image_features:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.image_token_index)
                         .unsqueeze(-1)
@@ -626,8 +632,13 @@ def forward(
                     )
                     image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                     inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
                 if video_outputs is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
+                    n_video_features = video_features.shape[1]
+                    if n_video_tokens != n_video_features:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )
                     special_image_mask = (
                         (input_ids == self.config.video_token_index)
                         .unsqueeze(-1)
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 26d92b9ac3dc..763482284767 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -511,6 +511,12 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                n_image_features = image_features.shape[1]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
                 special_image_mask = (
                     (input_ids == self.config.image_token_index)
                     .unsqueeze(-1)
diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py
index e183c38a59f7..07415900bb93 100644
--- a/tests/models/llava/test_modeling_llava.py
+++ b/tests/models/llava/test_modeling_llava.py
@@ -118,8 +118,8 @@ def __init__(
         self.batch_size = 3
         self.num_channels = 3
         self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
         self.seq_length = seq_length + self.num_image_tokens
 
     def get_config(self):
diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py
index b12f2c30c774..862e144ecdd7 100644
--- a/tests/models/vipllava/test_modeling_vipllava.py
+++ b/tests/models/vipllava/test_modeling_vipllava.py
@@ -111,8 +111,8 @@ def __init__(
         self.batch_size = 3
         self.num_channels = 3
         self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
         self.seq_length = seq_length + self.num_image_tokens
 
     def get_config(self):
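
Note (not part of the diff): every hunk above applies the same check-then-scatter pattern, so here is a minimal, self-contained sketch of that pattern in isolation. The token id, the function name, and the tensor shapes below are illustrative assumptions only; they are not taken from transformers or from any real checkpoint.

import torch

IMAGE_TOKEN_ID = 32000  # hypothetical placeholder id, for illustration only


def scatter_image_features(input_ids, inputs_embeds, image_features):
    # Count <image> placeholders in the text and compare against the number
    # of vision feature rows *before* scattering; this mirrors the guard the
    # diff adds in front of each masked_scatter call.
    n_image_tokens = (input_ids == IMAGE_TOKEN_ID).sum().item()
    n_image_features = image_features.shape[0]
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: "
            f"tokens: {n_image_tokens}, features {n_image_features}"
        )
    special_image_mask = (
        (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
    )
    return inputs_embeds.masked_scatter(
        special_image_mask, image_features.to(inputs_embeds.dtype)
    )


# 5 text positions, 2 of them <image> placeholders; hidden size 4.
ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3]])
embeds = torch.zeros(1, 5, 4)

scatter_image_features(ids, embeds, torch.ones(2, 4))  # counts match: scatters cleanly
try:
    # Three feature rows for two placeholders: masked_scatter alone would
    # silently consume only the first eight elements; the guard raises instead.
    scatter_image_features(ids, embeds, torch.ones(3, 4))
except ValueError as err:
    print(err)

The deliberately mismatched call at the end is the failure mode this PR targets: without the guard, a tokenizer/processor disagreement surfaces later as a shape error or, worse, as silently truncated features.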