```python
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
    (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
```

I think the logic of the last line is wrong.
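For reference, the last line keeps a `True` slot only once the running count of `True` values in its row exceeds `nb_image_pad`, so it clears the first `nb_image_pad` `True` slots of each row. The sketch below is a plain-Python mirror of that line for a single row (the function name is hypothetical, not library code). The example row is also hypothetical: it assumes the second row were left-padded as `[0, 1, 32000, 2, 3]`, in which case the single extra slot sits at position 0 and the image occupies slots 3 and 4.

```python
def apply_left_pad_mask(row, nb_pad):
    """Plain-Python mirror of
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
    for one row: the first `nb_pad` True slots are cleared."""
    out = []
    running = 0
    for flag in row:
        running += int(flag)  # cumsum of the boolean row
        out.append(flag and running - 1 >= nb_pad)
    return out


# hypothetical left-padded row: slot 0 is the extra padding slot,
# slots 3 and 4 belong to the image, the rest were written by tokens
row = [True, False, False, True, True, False, False]
print(apply_left_pad_mask(row, 1))
# [False, False, False, True, True, False, False]
```

With left padding the extra slots come first in the row, so "clear the first `nb_image_pad` `True` slots" removes exactly the padding slots and keeps the image slots.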
Consider the following example, where 32000 is the image token index, 0 is the padding index, and `num_patches=2`. `input_ids` is:

```
[[32000, 32000, 1, 2, 3],
 [1, 32000, 2, 3, 0]]
```
Then `new_token_positions` is:

```
[[1, 3, 4, 5, 6],
 [0, 2, 3, 4, 5]]
```

and `nb_image_pad` is `[0, 1]`.
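These values can be reproduced in plain Python. The sketch below assumes the position formulas used earlier in the same function, which are not shown in this issue: each image token expands to `num_patches` slots and every other token (including padding) to one, `new_token_positions` is the cumulative sum of those widths minus one, `max_embed_dim = max(#image tokens per row) * (num_patches - 1) + seq_len`, and `nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]`. The function name is hypothetical.

```python
from itertools import accumulate

IMAGE_TOKEN = 32000  # image token index from the example
NUM_PATCHES = 2      # num_patches from the example


def token_positions(input_ids, image_token=IMAGE_TOKEN, num_patches=NUM_PATCHES):
    """Return (new_token_positions, nb_image_pad, max_embed_dim) for a batch,
    using the assumed formulas described above."""
    seq_len = len(input_ids[0])
    n_images = [sum(t == image_token for t in row) for row in input_ids]
    max_embed_dim = max(n_images) * (num_patches - 1) + seq_len
    positions = []
    for row in input_ids:
        # an image token occupies num_patches slots, every other token one slot
        widths = [num_patches if t == image_token else 1 for t in row]
        positions.append([p - 1 for p in accumulate(widths)])
    # number of unused slots in each row of the merged sequence
    nb_image_pad = [max_embed_dim - 1 - pos[-1] for pos in positions]
    return positions, nb_image_pad, max_embed_dim


pos, pad, width = token_positions([[32000, 32000, 1, 2, 3], [1, 32000, 2, 3, 0]])
print(pos, pad, width)
# [[1, 3, 4, 5, 6], [0, 2, 3, 4, 5]] [0, 1] 7
```

Note that `nb_image_pad` only counts how many slots are unused in a row; it says nothing about whether those slots are at the start or the end.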
Before the last step, `image_to_overwrite` is:

```
[[True, True, True, True, False, False, False],
 [False, True, True, False, False, False, True]]
```
After the last step, `image_to_overwrite` is:

```
[[True, True, True, True, False, False, False],
 [False, False, True, False, False, False, True]]
```
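The two masks above can be checked with a small plain-Python mirror of step 5. It hardcodes `input_ids`, `new_token_positions` and `nb_image_pad` from the example, sets every slot written by a non-image token (text or padding) to `False`, and then applies the last line row by row. The helper names are hypothetical; this is a sketch for checking the numbers, not the library code.

```python
IMAGE_TOKEN = 32000

# values taken from the example in this issue
input_ids = [[32000, 32000, 1, 2, 3],
             [1, 32000, 2, 3, 0]]
new_token_positions = [[1, 3, 4, 5, 6],
                       [0, 2, 3, 4, 5]]
nb_image_pad = [0, 1]
max_embed_dim = 7


def build_mask(ids_row, pos_row, width):
    # start with every slot True, then clear each slot a non-image token is written to
    mask = [True] * width
    for tok, pos in zip(ids_row, pos_row):
        if tok != IMAGE_TOKEN:
            mask[pos] = False
    return mask


def apply_last_line(mask, pad):
    # mirror of: image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
    out, running = [], 0
    for flag in mask:
        running += int(flag)
        out.append(flag and running - 1 >= pad)
    return out


before = [build_mask(i, p, max_embed_dim) for i, p in zip(input_ids, new_token_positions)]
after = [apply_last_line(m, pad) for m, pad in zip(before, nb_image_pad)]
print(before)
print(after)
```

In the second row the only extra slot is slot 6 at the end, but the last line clears the first `True` it sees (slot 1), which belongs to the image.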
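However, the correct result should be:

```
[[True, True, True, True, False, False, False],
 [False, True, True, False, False, False, False]]
```

In the second row, the image token expands into slots 1 and 2, and slot 6 is the extra slot left over because the row is right-padded. The last line assumes the `nb_image_pad` extra slots are at the start of the row, so it clears slot 1 (a real image slot) and keeps slot 6.

I think the code only handles left padding; with right padding it needs some modifications. Here is my code: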
```python
image_to_overwrite = torch.full(
    (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
if left_padding:
    # extra slots are at the start of each row: clear the first `nb_image_pad` True slots
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
else:
    # extra slots are at the end of each row: keep only slots up to the last token's new position
    image_to_overwrite &= torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1 <= new_token_positions[:, -1:].to(target_device)
```
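To check the proposed `else` branch on the example, here is a plain-Python mirror of it (the function name is hypothetical). `ones_like(...).cumsum(-1) - 1` is simply the column index of each slot, so the branch keeps a slot only if its index is at most the new position of the row's last token. The inputs are the "before the last step" mask and `new_token_positions[:, -1]` from the example.

```python
def apply_right_pad_mask(mask, last_pos):
    """Plain-Python mirror of the proposed else-branch for one row:
    keep a slot only if its index <= new position of the last token,
    because any slot after it is right padding, not an image slot."""
    return [flag and idx <= last_pos for idx, flag in enumerate(mask)]


# image_to_overwrite before the last step, from the example
before = [[True, True, True, True, False, False, False],
          [False, True, True, False, False, False, True]]
last_positions = [6, 5]  # new_token_positions[:, -1] from the example

fixed = [apply_right_pad_mask(m, p) for m, p in zip(before, last_positions)]
print(fixed)
# [[True, True, True, True, False, False, False], [False, True, True, False, False, False, False]]
```

This matches the expected result above. In the tensor version, `torch.arange(max_embed_dim, device=image_to_overwrite.device)` should produce the same column indices as the `ones_like(...).cumsum(-1) - 1` expression and broadcast against `new_token_positions[:, -1:]` in the same way.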