From 4ea3e0ed730351da092c6d6371005ec0d6ff716a Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Fri, 20 Sep 2024 06:03:59 +0000 Subject: [PATCH 1/8] [feat] add feature size check to avoid CUDA Runtime Error --- .../models/llava_next/modeling_llava_next.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bb07ddfdaccd..ddda3c0d43a5 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -897,11 +897,18 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) outputs = self.language_model( attention_mask=attention_mask, From 610bc7527b0ccd1e7082f674910efcb389b94941 Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Thu, 26 Sep 2024 06:25:53 +0000 Subject: [PATCH 2/8] [minor] add error handling to all llava models --- .../models/llava/modeling_llava.py | 17 ++++++--- .../modeling_llava_next_video.py | 34 ++++++++++++----- .../modular_llava_next_video.py | 34 ++++++++++++----- .../modeling_llava_onevision.py | 24 +++++++----- .../video_llava/modeling_video_llava.py | 37 +++++++++++++------ .../models/vipllava/modeling_vipllava.py | 17 ++++++--- 6 files changed, 112 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 092008873d1e..df87dec07582 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -511,11 +511,18 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) outputs = self.language_model( attention_mask=attention_mask, diff --git 
a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 46b9b23bd66b..d732fcefdf75 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -979,17 +979,31 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: if image_features is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens == n_video_features: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + else: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index f48056cfb97e..5e7bf97761eb 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -484,17 +484,31 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: if image_features is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: 
{n_image_tokens}, features {n_image_features}" + ) if video_features is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens == n_video_features: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + else: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index c378ff09f1e4..4b0737023d1e 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -621,15 +621,21 @@ def forward( image_newline=self.image_newline, vision_aspect_ratio=vision_aspect_ratio, ) - - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) # Video are simply embedded and further pooled to decrease seq len if pixel_values_videos is not None: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 7c7cfec20959..80b9df25234c 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -620,18 +620,31 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: if image_outputs is not None: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - - if video_outputs is not None: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + n_image_tokens = (input_ids == 
self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + if video_features is not None: + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens == n_video_features: + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + else: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 95129d46bbd8..4fbcf8dd5ab3 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -506,11 +506,18 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] + if n_image_tokens == n_image_features: + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + else: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) outputs = self.language_model( attention_mask=attention_mask, From 353c610ea464c299e69dca38f2fe476ef3fe3089 Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 05:36:49 +0000 Subject: [PATCH 3/8] [minor] avoid nested if else --- .../models/llava/modeling_llava.py | 13 +++++----- .../models/llava_next/modeling_llava_next.py | 13 +++++----- .../modeling_llava_next_video.py | 26 +++++++++---------- .../modular_llava_next_video.py | 26 +++++++++---------- .../modeling_llava_onevision.py | 19 +++++++------- .../video_llava/modeling_video_llava.py | 26 +++++++++---------- .../models/vipllava/modeling_vipllava.py | 13 +++++----- 7 files changed, 63 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index df87dec07582..fe9eb6ea8ab6 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -513,16 +513,15 @@ def forward( else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = 
image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index ddda3c0d43a5..52c37df57aa1 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -899,16 +899,15 @@ def forward( else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index d732fcefdf75..aad21e28a6b2 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -981,29 +981,27 @@ def forward( if image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_features is not None: n_video_tokens = (input_ids == self.config.video_token_index).sum().item() 
n_video_features = video_features.shape[0] - if n_video_tokens == n_video_features: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - else: + if n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 5e7bf97761eb..9df11d6e688c 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -486,29 +486,27 @@ def forward( if image_features is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_features is not None: n_video_tokens = (input_ids == self.config.video_token_index).sum().item() n_video_features = video_features.shape[0] - if n_video_tokens == n_video_features: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - else: + if n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 4b0737023d1e..1bf1e9dfa984 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -623,19 +623,18 @@ def forward( ) n_image_tokens = (input_ids 
== self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # Video are simply embedded and further pooled to decrease seq len if pixel_values_videos is not None: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 80b9df25234c..2508dbf652a1 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -622,29 +622,27 @@ def forward( if image_outputs is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_features is not None: n_video_tokens = (input_ids == self.config.video_token_index).sum().item() n_video_features = video_features.shape[0] - if n_video_tokens == n_video_features: - special_image_mask = ( - (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) - else: + if n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) + special_image_mask = ( + (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 4fbcf8dd5ab3..e39ff4a76606 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ 
-508,16 +508,15 @@ def forward( else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() n_image_features = image_features.shape[0] - if n_image_tokens == n_image_features: - special_image_mask = ( - (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - else: + if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) outputs = self.language_model( attention_mask=attention_mask, From 53d407df920f251992ac56c409db89feb683eef4 Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 05:50:40 +0000 Subject: [PATCH 4/8] [minor] add error message to Qwen2-vl and chameleon --- .../models/chameleon/modeling_chameleon.py | 6 ++++++ .../models/qwen2_vl/modeling_qwen2_vl.py | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index c4eb1eade6e2..2bef9544dfb7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1343,6 +1343,12 @@ def forward( if pixel_values is not None: image_tokens = self.get_image_tokens(pixel_values) + n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item() + n_image_features = image_tokens.shape[0] + if n_image_tokens_in_text != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}" + ) special_image_mask = input_ids == self.vocabulary_mapping.image_token_id image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) input_ids = input_ids.masked_scatter(special_image_mask, image_tokens) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 938ec4d5e423..d27263320d4e 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1689,6 +1689,12 @@ def forward( if pixel_values is not None: pixel_values = pixel_values.type(self.visual.get_dtype()) image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) @@ -1696,6 +1702,12 @@ def forward( if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == 
self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) From fc076402af1447977df4fc81443dc1699a0181ed Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 06:09:15 +0000 Subject: [PATCH 5/8] [fix] token dimension for check --- src/transformers/models/llava/modeling_llava.py | 2 +- src/transformers/models/video_llava/modeling_video_llava.py | 6 +++--- src/transformers/models/vipllava/modeling_vipllava.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index fe9eb6ea8ab6..6e33bc7ea16e 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -512,7 +512,7 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] + n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 2508dbf652a1..bb27672cf9cb 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -621,7 +621,7 @@ def forward( else: if image_outputs is not None: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] + n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" @@ -631,9 +631,9 @@ def forward( ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_features is not None: + if video_outputs is not None: n_video_tokens = (input_ids == self.config.video_token_index).sum().item() - n_video_features = video_features.shape[0] + n_video_features = video_features.shape[1] if n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index e39ff4a76606..93470653bd1c 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -507,7 +507,7 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: n_image_tokens = (input_ids == self.config.image_token_index).sum().item() - n_image_features = image_features.shape[0] + n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: 
{n_image_tokens}, features {n_image_features}" From 9ac6689681e4835dc18d20dbdb2dc436bf9ede6b Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 06:17:13 +0000 Subject: [PATCH 6/8] [minor] add feature dim check for videos too --- .../models/llava_onevision/modeling_llava_onevision.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 1bf1e9dfa984..67bae39bfc21 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -654,7 +654,12 @@ def forward( image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device) video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) - + n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_features = video_features.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) special_video_mask = ( (input_ids == self.config.video_token_index) .unsqueeze(-1) From edab0fc8694f873fff50eb5dee0403ee25201876 Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 06:40:00 +0000 Subject: [PATCH 7/8] [fix] dimension check --- src/transformers/models/llava/modeling_llava.py | 2 +- src/transformers/models/video_llava/modeling_video_llava.py | 4 ++-- src/transformers/models/vipllava/modeling_vipllava.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 6e33bc7ea16e..09f1a612e409 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -511,7 +511,7 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index bb27672cf9cb..6c274e8e5f74 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -620,7 +620,7 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: if image_outputs is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( @@ -632,7 +632,7 @@ def forward( image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_outputs is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum().item() + n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() n_video_features = video_features.shape[1] if n_video_tokens != n_video_features: raise ValueError( diff --git 
a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 93470653bd1c..66b2498a9d8e 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -506,7 +506,7 @@ def forward( # TODO: @raushan retain only the new behavior after v4.47 else: - n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() n_image_features = image_features.shape[1] if n_image_tokens != n_image_features: raise ValueError( From bea8910b2db542bb43508e7b19754cbce0feaa87 Mon Sep 17 00:00:00 2001 From: laurentd-lunit Date: Mon, 30 Sep 2024 07:39:49 +0000 Subject: [PATCH 8/8] [fix] test reference values --- tests/models/llava/test_modeling_llava.py | 4 ++-- tests/models/vipllava/test_modeling_vipllava.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index e183c38a59f7..07415900bb93 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -118,8 +118,8 @@ def __init__( self.batch_size = 3 self.num_channels = 3 self.image_size = 336 - self.encoder_seq_length = 231 - self.num_image_tokens = 224 + self.encoder_seq_length = 232 + self.num_image_tokens = 225 self.seq_length = seq_length + self.num_image_tokens def get_config(self): diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index b12f2c30c774..862e144ecdd7 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -111,8 +111,8 @@ def __init__( self.batch_size = 3 self.num_channels = 3 self.image_size = 336 - self.encoder_seq_length = 231 - self.num_image_tokens = 224 + self.encoder_seq_length = 232 + self.num_image_tokens = 225 self.seq_length = seq_length + self.num_image_tokens def get_config(self):
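
Note on the change above: every patch in this series adds the same guard before scattering vision features into the text embeddings — count the placeholder image/video tokens in input_ids, compare that count against the number of projected feature rows, and raise a ValueError up front instead of letting masked_scatter fail later with an opaque CUDA runtime error. The snippet below is a minimal, self-contained sketch of that logic using toy tensors; the token index, shapes, and values are invented for illustration, and the exact quantity compared (features shape[0] vs. shape[1], global vs. per-sequence token count) differs per model and is adjusted by the later patches in this series.

import torch

# Toy values, not the library's real configuration: a hypothetical image
# placeholder token id and a tiny hidden size for demonstration only.
image_token_index = 32000
hidden_size = 4

# One sequence containing three image placeholder tokens.
input_ids = torch.tensor([[1, 32000, 32000, 32000, 2]])
inputs_embeds = torch.zeros(1, 5, hidden_size)

# Projected vision features: one row per expected image placeholder token.
image_features = torch.randn(3, hidden_size)

# The check the patches introduce: token count must match feature count.
n_image_tokens = (input_ids == image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
    # Previously the mismatch only surfaced inside masked_scatter as a
    # hard-to-debug CUDA runtime error; now it fails fast with a clear message.
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )

# Scatter the feature rows into the placeholder positions.
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds.shape)  # torch.Size([1, 5, 4])

With a mismatched feature count (for example image_features = torch.randn(2, hidden_size) against three placeholder tokens), the same sketch raises the new ValueError before masked_scatter runs, which is the failure mode these patches are designed to surface early.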