diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index d25cf5e2f2a1..97ad4c711651 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -147,8 +147,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype) fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index c2d41aac02d7..a19099d14302 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -147,8 +147,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype) fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6) diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 745206868581..23db21c8d68a 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -142,8 +142,8 @@ def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.B nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() - h_indices = torch.arange(nb_patches_h, device=pixel_values.device, dtype=pixel_values.dtype) - w_indices = torch.arange(nb_patches_w, device=pixel_values.device, dtype=pixel_values.dtype) + h_indices = torch.arange(nb_patches_h, device=position_ids.device, dtype=position_ids.dtype) + w_indices = torch.arange(nb_patches_w, device=position_ids.device, dtype=position_ids.dtype) fractional_coords_h = h_indices / nb_patches_h * (1 - 1e-6) fractional_coords_w = w_indices / nb_patches_w * (1 - 1e-6)