diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 507e0768a226..9d9bbcb90c0d 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -308,7 +308,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: _, _, height, width = pixel_values.shape - patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] embeddings = patch_embeds.flatten(2).transpose(1, 2) if interpolate_pos_encoding: