src/transformers/models/owlv2/modeling_owlv2.py (28 additions, 23 deletions)
@@ -18,7 +18,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1312,27 +1311,27 @@ def __init__(self, config: Owlv2Config):
         self.sigmoid = nn.Sigmoid()

         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches, self.owlv2.device)

     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(self, num_patches: int, device: torch.device) -> torch.Tensor:
+        """
+        Computes normalized xy corner coordinates for a square grid of patches.
+        Args:
+            num_patches: Number of patches along each side of the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_coordinates: Normalized xy corner coordinates of shape (num_patches**2, 2).
+        """
+        # indexing="xy" matches the np.meshgrid default this replaces; torch.meshgrid
+        # defaults to "ij", which would swap the x and y coordinates.
+        box_coordinates = torch.stack(
+            torch.meshgrid(torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"), dim=-1
+        ).to(torch.float32)
+        box_coordinates /= num_patches

         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        ).to(device)

         return box_coordinates

@@ -1351,20 +1350,26 @@ def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.Float
         return objectness_logits

     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
-        # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+    def compute_box_bias(self, num_patches: int, device: torch.device) -> torch.FloatTensor:
+        """
+        Computes the box bias for bounding box prediction (the box center is biased toward its position on the
+        feature grid).
+        Args:
+            num_patches: Number of patches along each side of the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_bias: Bias term of the bounding box prediction.
+        """
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches, device)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

         # Compute box bias
         box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)

         return box_bias

     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.box_predictor
@@ -1387,7 +1392,7 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)

         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        pred_boxes += self.box_bias.to(feature_map.device)
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes

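A quick standalone sanity check (not part of the PR; names chosen for illustration) that the torch rewrite reproduces the coordinates of the removed numpy code. One caveat worth flagging in review: torch.meshgrid defaults to "ij" indexing while np.meshgrid defaults to "xy", so the equivalence holds only with an explicit indexing="xy", as added to the hunks above:

import numpy as np
import torch

num_patches = 4  # arbitrary small grid for the check

# Old numpy path, as removed by this diff
old = np.stack(
    np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
).astype(np.float32)
old /= np.array([num_patches, num_patches], np.float32)
old = old.reshape(-1, 2)

# New torch path, as added by this diff, with explicit "xy" indexing
new = torch.stack(
    torch.meshgrid(
        torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"
    ),
    dim=-1,
).to(torch.float32)
new = (new / num_patches).reshape(-1, 2)

# Passes with indexing="xy"; with the default "ij" the x and y columns come out swapped.
assert torch.allclose(torch.from_numpy(old), new)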
src/transformers/models/owlvit/modeling_owlvit.py (28 additions, 23 deletions)
@@ -18,7 +18,6 @@
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1293,43 +1292,49 @@ def __init__(self, config: OwlViTConfig):
         self.sigmoid = nn.Sigmoid()

         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches, self.owlvit.device)

-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
-
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+    def normalize_grid_corner_coordinates(self, num_patches: int, device: torch.device) -> torch.Tensor:
+        """
+        Computes normalized xy corner coordinates for a square grid of patches.
+        Args:
+            num_patches: Number of patches along each side of the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_coordinates: Normalized xy corner coordinates of shape (num_patches**2, 2).
+        """
+        # indexing="xy" matches the np.meshgrid default this replaces; torch.meshgrid
+        # defaults to "ij", which would swap the x and y coordinates.
+        box_coordinates = torch.stack(
+            torch.meshgrid(torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"), dim=-1
+        ).to(torch.float32)
+        box_coordinates /= num_patches

         # Flatten (h, w, 2) -> (h*w, 2)
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        ).to(device)

         return box_coordinates

-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
-        # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+    def compute_box_bias(self, num_patches: int, device: torch.device) -> torch.FloatTensor:
+        """
+        Computes the box bias for bounding box prediction (the box center is biased toward its position on the
+        feature grid).
+        Args:
+            num_patches: Number of patches along each side of the feature map.
+            device: Device on which to create the tensor.
+        Returns:
+            box_bias: Bias term of the bounding box prediction.
+        """
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches, device)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)

         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)

         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)

         # Compute box bias
         box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)

         return box_bias

     def box_predictor(
@@ -1351,7 +1356,7 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)

         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        pred_boxes += self.box_bias.to(feature_map.device)
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes

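To make the role of the new precomputed buffer concrete, here is a minimal sketch that re-implements the bias math outside the models (box_bias_sketch is a made-up name, not part of the PR). It shows that sigmoid(box_bias) is a grid of default boxes, with centers at the patch positions and width/height of one patch, so a zero box-head output already predicts the patch-aligned default box and the head only learns offsets:

import torch

def box_bias_sketch(num_patches: int) -> torch.Tensor:
    # Normalized per-patch grid coordinates, mirroring normalize_grid_corner_coordinates
    coords = torch.stack(
        torch.meshgrid(
            torch.arange(1, num_patches + 1), torch.arange(1, num_patches + 1), indexing="xy"
        ),
        dim=-1,
    ).to(torch.float32)
    coords = (coords / num_patches).reshape(-1, 2).clip(0.0, 1.0)
    # Inverse sigmoid (logit) with a small epsilon, mirroring compute_box_bias
    coord_bias = torch.log(coords + 1e-4) - torch.log1p(-coords + 1e-4)
    size = torch.full_like(coord_bias, 1.0 / num_patches)
    size_bias = torch.log(size + 1e-4) - torch.log1p(-size + 1e-4)
    return torch.cat([coord_bias, size_bias], dim=-1)

bias = box_bias_sketch(4)
# Each row is (cx, cy, w, h) in logit space; sigmoid recovers the default boxes.
print(torch.sigmoid(bias)[0])  # first row is approximately [0.25, 0.25, 0.25, 0.25]

Since the bias depends only on the config-derived sqrt_num_patches, computing it once in __init__ (on the model's device at construction time, typically CPU) and moving it with .to(feature_map.device) in box_predictor avoids recomputing it on every forward pass, which is the point of this change.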