diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 18d7702a3d05..1b15161ed760 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -141,8 +141,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B embeddings = patch_embeds.flatten(2).transpose(1, 2) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device + ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() @@ -158,9 +162,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids - position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index a19099d14302..541658f2ff59 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -140,8 +140,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B embeddings = patch_embeds.flatten(2).transpose(1, 2) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device + ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() @@ -157,9 +161,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids - position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 23db21c8d68a..a2295f12ce9f 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -135,8 +135,12 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B embeddings = patch_embeds.flatten(2).transpose(1, 2) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size - boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) - position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side, device=pixel_values.device + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0, device=pixel_values.device + ) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() @@ -152,9 +156,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() - position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids - position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings