diff --git a/src/transformers/models/owlv2/modeling_owlv2.py b/src/transformers/models/owlv2/modeling_owlv2.py
index e538d2b4d408..657c4e01e37b 100644
--- a/src/transformers/models/owlv2/modeling_owlv2.py
+++ b/src/transformers/models/owlv2/modeling_owlv2.py
@@ -18,7 +18,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1312,27 +1311,27 @@ def __init__(self, config: Owlv2Config):
         self.sigmoid = nn.Sigmoid()
 
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches, self.owlv2.device)
 
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(self, num_patches: int, device: torch.device) -> torch.Tensor:
+        """
+        Computes normalized xy corner coordinates for a `num_patches` x `num_patches` grid.
+        Args:
+            num_patches: Number of patches in the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_coordinates: Normalized xy corner coordinates.
+        """
+        box_coordinates = torch.stack(
+            torch.meshgrid(torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"), dim=-1
+        ).to(torch.float32)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        ).to(device)
 
         return box_coordinates
 
@@ -1351,20 +1350,26 @@ def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor:
         return objectness_logits
 
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
-        # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+    def compute_box_bias(self, num_patches: int, device: torch.device) -> torch.FloatTensor:
+        """
+        Computes box bias for bounding box prediction (the box center is biased to its position on the feature grid).
+        Args:
+            num_patches: Number of patches in the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_bias: Bias term for the bounding box prediction.
+        """
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches, device)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
-
-        # Compute box bias
         box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
+
         return box_bias
 
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.box_predictor
@@ -1387,7 +1392,7 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        pred_boxes += self.box_bias.to(feature_map.device)
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
 
diff --git a/src/transformers/models/owlvit/modeling_owlvit.py b/src/transformers/models/owlvit/modeling_owlvit.py
index a06610a643bb..bcab4c8c6bd9 100644
--- a/src/transformers/models/owlvit/modeling_owlvit.py
+++ b/src/transformers/models/owlvit/modeling_owlvit.py
@@ -18,7 +18,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1293,43 +1292,49 @@ def __init__(self, config: OwlViTConfig):
         self.sigmoid = nn.Sigmoid()
 
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches, self.owlvit.device)
 
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(self, num_patches: int, device: torch.device) -> torch.Tensor:
+        """
+        Computes normalized xy corner coordinates for a `num_patches` x `num_patches` grid.
+        Args:
+            num_patches: Number of patches in the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_coordinates: Normalized xy corner coordinates.
+        """
+        box_coordinates = torch.stack(
+            torch.meshgrid(torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"), dim=-1
+        ).to(torch.float32)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        ).to(device)
 
         return box_coordinates
 
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
-        # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+    def compute_box_bias(self, num_patches: int, device: torch.device) -> torch.FloatTensor:
+        """
+        Computes box bias for bounding box prediction (the box center is biased to its position on the feature grid).
+        Args:
+            num_patches: Number of patches in the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_bias: Bias term for the bounding box prediction.
+        """
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches, device)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
-
-        # Compute box bias
         box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
+
         return box_bias
 
     def box_predictor(
@@ -1351,7 +1356,7 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        pred_boxes += self.box_bias.to(feature_map.device)
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes