6 changes: 6 additions & 0 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1287,6 +1287,12 @@ def forward(

        if pixel_values is not None:
            image_tokens = self.get_image_tokens(pixel_values)
            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
Collaborator:
Mmmm, in general I don't mind, as this should help our users, but the .item() might break compile compatibility (well, only full graph).

@McPatate that's where and when we would need to see how much we are losing from this small change! 🤗 (FYI @LysandreJik)

Member:
We can always wrap these in is_torchdynamo_compiling, the same way we wrap all warnings/logging now in generation code. So we ask users to make sure the code works without compilation, to see all warnings etc., and then compile the code, which will not show the exact reason why/where this CUDA-side error was triggered.
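A minimal sketch of the wrapping suggested here. The helper below is a hypothetical simplified stand-in for `transformers.utils.is_torchdynamo_compiling` (the real helper covers more torch versions), and `check_image_tokens` is an illustrative name, not the actual API:

```python
import torch

# Hypothetical simplified stand-in for transformers.utils.is_torchdynamo_compiling.
def is_torchdynamo_compiling():
    try:
        return torch.compiler.is_compiling()
    except AttributeError:  # older torch without torch.compiler
        return False

def check_image_tokens(input_ids, image_features, image_token_id):
    # .item() forces a host sync (and breaks full-graph compile), so the
    # sanity check only runs in eager mode and is skipped under compilation.
    if not is_torchdynamo_compiling():
        n_image_tokens = (input_ids == image_token_id).sum().item()
        n_image_features = image_features.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: "
                f"tokens: {n_image_tokens}, features {n_image_features}"
            )
```

Users would then run once without compilation to see all checks and warnings, and only afterwards compile.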

Collaborator:
Okay, that makes sense. Just 🥶 to more checks, but this one is most probably cached, so it should be alright.

Collaborator:
The thing is, these is_compiling checks are unrelated to normal users, so they expose them to unrelated code paths.

Member:
I see what you mean. Yes, the processing should maybe check this, but we cannot perform any checks before getting the image hidden states. My main idea was to bring back the same check we had earlier in the merge_inputs method, so that after moving to the new logic we can still trace down bugs related to shape mismatches easily, or let users track them down.

Also, we won't do the sum() and item() on every forward; for generation it runs only during the prefill stage, after which we'll have the image states in the cache. But anyway, if you think this is too many checks (given we now support both the old and new logic in VLMs for a few minor releases), I am okay with not adding it. I don't see it as a major blocker or anything, more like a nice addition for users :D

Collaborator:
Okay let's add it then 🤗

            n_image_features = image_tokens.shape[0]
            if n_image_tokens_in_text != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
                )
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
            image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
            input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
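As a toy illustration of the chameleon merge above (token ids are hypothetical, chosen only for this sketch): masked_scatter replaces each placeholder position in input_ids, in order, with the quantized ids derived from pixel_values.

```python
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id for this sketch
input_ids = torch.tensor([[1, 9, 9, 2]])   # two image placeholders
image_tokens = torch.tensor([501, 502])    # one quantized id per placeholder

special_image_mask = input_ids == IMAGE_TOKEN_ID
merged = input_ids.masked_scatter(special_image_mask, image_tokens)
# merged is [[1, 501, 502, 2]]
```

This is why the count check above matters: if the number of True mask positions and the number of image tokens disagree, masked_scatter fails with a far less readable error.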
6 changes: 6 additions & 0 deletions src/transformers/models/llava/modeling_llava.py
@@ -518,6 +518,12 @@ def forward(

        # TODO: @raushan retain only the new behavior after v4.47
        else:
            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
            n_image_features = image_features.shape[1]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
Comment on lines +521 to +526
Collaborator @ArthurZucker, Oct 1, 2024:

I don't know why we are adding this here, as the processor is supposed to check this for the non-legacy path!

Member:
Yes, it is supposed to. There was only one edge case with llava-next, which uses a pad/unpad technique; since we used tensors in modeling, there were minor numerical inconsistencies.

Right now it should work, but in general, IMO, it's a good idea to help users pinpoint what went wrong in their code.

Collaborator:
Not in the forward pass, IMO; we are adding extra processing (.sum() and .item(), as seen above) that runs on every single forward pass. The biggest issue for me is the duplicated work!

            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
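The llava-style merge operates on embeddings rather than ids. A minimal sketch with hypothetical toy shapes (the token id and sizes are invented for illustration) shows how the boolean mask is broadcast over the hidden dimension before masked_scatter overwrites the placeholder embeddings in order:

```python
import torch

IMAGE_TOKEN_INDEX = 7  # hypothetical placeholder id for this sketch
input_ids = torch.tensor([[5, 7, 7, 6]])
inputs_embeds = torch.zeros(1, 4, 3)               # (batch, seq, hidden)
image_features = torch.arange(6.0).reshape(2, 3)   # one row per placeholder

special_image_mask = (
    (input_ids == IMAGE_TOKEN_INDEX)
    .unsqueeze(-1)                 # (batch, seq, 1)
    .expand_as(inputs_embeds)      # broadcast over the hidden dimension
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
```

masked_scatter fills True positions in row-major order from the source, so the first feature row lands on the first placeholder, the second on the next, and so on.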
6 changes: 6 additions & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py
@@ -895,6 +895,12 @@ def forward(

        # TODO: @raushan retain only the new behavior after v4.47
        else:
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
@@ -967,6 +967,12 @@ def forward(
        # TODO: @raushan retain only the new behavior after v4.47
        else:
            if image_features is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
@@ -976,6 +982,12 @@ def forward(
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
            if video_features is not None:
                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
                n_video_features = video_features.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.video_token_index)
                    .unsqueeze(-1)
@@ -482,6 +482,12 @@ def forward(
        # TODO: @raushan retain only the new behavior after v4.47
        else:
            if image_features is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
                n_image_features = image_features.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
@@ -491,6 +497,12 @@ def forward(
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
            if video_features is not None:
                n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
                n_video_features = video_features.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.video_token_index)
                    .unsqueeze(-1)
@@ -619,7 +619,12 @@ def forward(
                image_newline=self.image_newline,
                vision_aspect_ratio=vision_aspect_ratio,
            )
            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
            n_image_features = image_features.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
@@ -647,7 +652,12 @@ def forward(
            image_newline = self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device)
            video_features = torch.cat((video_features, image_newline), dim=1)
            video_features = video_features.flatten(0, 1)
            n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
            n_video_features = video_features.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )
            special_video_mask = (
                (input_ids == self.config.video_token_index)
                .unsqueeze(-1)
12 changes: 12 additions & 0 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1710,6 +1710,12 @@ def forward(
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_mask = (
                (input_ids == self.config.image_token_id)
                .unsqueeze(-1)
@@ -1722,6 +1728,12 @@ def forward(
        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )
            video_mask = (
                (input_ids == self.config.video_token_id)
                .unsqueeze(-1)
13 changes: 12 additions & 1 deletion src/transformers/models/video_llava/modeling_video_llava.py
@@ -618,6 +618,12 @@ def forward(
        # TODO: @raushan retain only the new behavior after v4.47
        else:
            if image_outputs is not None:
                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
                n_image_features = image_features.shape[1]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.image_token_index)
                    .unsqueeze(-1)
@@ -626,8 +632,13 @@ def forward(
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
            if video_outputs is not None:
                n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item()
                n_video_features = video_features.shape[1]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                special_image_mask = (
                    (input_ids == self.config.video_token_index)
                    .unsqueeze(-1)
6 changes: 6 additions & 0 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -511,6 +511,12 @@ def forward(

        # TODO: @raushan retain only the new behavior after v4.47
        else:
            n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
            n_image_features = image_features.shape[1]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            special_image_mask = (
                (input_ids == self.config.image_token_index)
                .unsqueeze(-1)
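Note the two counting styles across these diffs: models whose features are flattened to (num_tokens, hidden) compare features.shape[0] against .sum().item() over the whole batch, while the llava/vipllava/video_llava legacy path compares features.shape[1] against a per-row count taken from the first batch row via .sum(dim=-1)[0].item(). A toy example (the token id and batch are hypothetical, for illustration only):

```python
import torch

IMAGE_TOKEN_INDEX = 7  # hypothetical placeholder id for this sketch
input_ids = torch.tensor([[1, 7, 7, 2],
                          [7, 7, 3, 4]])

# Whole-batch count: pairs with features flattened to (num_tokens, hidden).
total = (input_ids == IMAGE_TOKEN_INDEX).sum().item()

# First-row count: pairs with features shaped (batch, num_tokens, hidden).
per_first_row = (input_ids == IMAGE_TOKEN_INDEX).sum(dim=-1)[0].item()
```

The first-row variant implicitly assumes every sequence in the batch carries the same number of placeholders, which holds for the legacy padded path.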
4 changes: 2 additions & 2 deletions tests/models/llava/test_modeling_llava.py
@@ -118,8 +118,8 @@ def __init__(
        self.batch_size = 3
        self.num_channels = 3
        self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
        self.seq_length = seq_length + self.num_image_tokens

    def get_config(self):
4 changes: 2 additions & 2 deletions tests/models/vipllava/test_modeling_vipllava.py
@@ -111,8 +111,8 @@ def __init__(
        self.batch_size = 3
        self.num_channels = 3
        self.image_size = 336
-        self.encoder_seq_length = 231
-        self.num_image_tokens = 224
+        self.encoder_seq_length = 232
+        self.num_image_tokens = 225
        self.seq_length = seq_length + self.num_image_tokens

    def get_config(self):