41 changes: 20 additions & 21 deletions src/transformers/models/owlv2/modeling_owlv2.py
@@ -16,9 +16,9 @@
 
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1312,27 +1312,22 @@ def __init__(self, config: Owlv2Config):
         self.sigmoid = nn.Sigmoid()
 
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
+    @staticmethod
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.normalize_grid_corner_coordinates
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
 
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
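
[Review note] `torch.meshgrid(..., indexing="xy")` matches the layout of numpy's default `np.meshgrid`, so the torch path reproduces the removed numpy path exactly. A minimal standalone check (illustrative, not part of the PR):

```python
import numpy as np
import torch

num_patches = 4  # small grid, for illustration

# Removed numpy path
old = np.stack(
    np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
).astype(np.float32)
old /= np.array([num_patches, num_patches], np.float32)
old = old.reshape(-1, 2)

# New torch path
coords = torch.arange(1, num_patches + 1, dtype=torch.float32)
xx, yy = torch.meshgrid(coords, coords, indexing="xy")
new = (torch.stack((xx, yy), dim=-1) / num_patches).view(-1, 2)

assert torch.allclose(torch.from_numpy(old), new)  # same (h*w, 2) grid
```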

@@ -1350,17 +1345,20 @@ def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor:
         objectness_logits = objectness_logits[..., 0]
         return objectness_logits
 
+    @lru_cache(maxsize=2)
     # Copied from transformers.models.owlvit.modeling_owlvit.OwlViTForObjectDetection.compute_box_bias
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
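
[Review note] For context, the bias math itself is unchanged: `log(x + 1e-4) - log1p(-x + 1e-4)` is a numerically stabilized logit, i.e. the inverse of the sigmoid applied at the end of `box_predictor`, so adding it biases each predicted box toward its own grid cell. A quick sanity sketch:

```python
import torch

x = torch.tensor([0.25, 0.50, 0.75])
bias = torch.log(x + 1e-4) - torch.log1p(-x + 1e-4)  # stabilized logit(x)
assert torch.allclose(torch.sigmoid(bias), x, atol=1e-3)  # sigmoid undoes it
```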
@@ -1387,7 +1385,8 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
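
[Review note] Since `self.box_bias` is now built once in `__init__`, each forward pass pays only a device move (a no-op when the bias is already on the right device) instead of rebuilding the grid. The user-facing API is unchanged; a minimal sketch with the standard public checkpoint (shown for illustration):

```python
import requests
import torch
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[["a photo of a cat"]], images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)  # box bias is reused from __init__, not rebuilt per call
```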

41 changes: 20 additions & 21 deletions src/transformers/models/owlvit/modeling_owlvit.py
@@ -16,9 +16,9 @@
 
 import warnings
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, Dict, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import Tensor, nn
@@ -1293,39 +1293,37 @@ def __init__(self, config: OwlViTConfig):
         self.sigmoid = nn.Sigmoid()
 
         self.sqrt_num_patches = config.vision_config.image_size // config.vision_config.patch_size
+        self.box_bias = self.compute_box_bias(self.sqrt_num_patches)
 
-    def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
-        # Computes normalized xy corner coordinates from feature_map.
-        if not feature_map.ndim == 4:
-            raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
+    @staticmethod
+    def normalize_grid_corner_coordinates(num_patches: int) -> torch.Tensor:
+        # Create grid coordinates using torch
+        x_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        y_coordinates = torch.arange(1, num_patches + 1, dtype=torch.float32)
+        xx, yy = torch.meshgrid(x_coordinates, y_coordinates, indexing="xy")
 
-        device = feature_map.device
-        num_patches = feature_map.shape[1]
-
-        # TODO: Remove numpy usage.
-        box_coordinates = np.stack(
-            np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
-        ).astype(np.float32)
-        box_coordinates /= np.array([num_patches, num_patches], np.float32)
+        # Stack the coordinates and divide by num_patches
+        box_coordinates = torch.stack((xx, yy), dim=-1)
+        box_coordinates /= num_patches
 
         # Flatten (h, w, 2) -> (h*w, 2)
-        box_coordinates = box_coordinates.reshape(
-            box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
-        )
-        box_coordinates = torch.from_numpy(box_coordinates).to(device)
+        box_coordinates = box_coordinates.view(-1, 2)
 
         return box_coordinates
 
-    def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
+    @lru_cache(maxsize=2)
+    def compute_box_bias(self, num_patches: int, feature_map: Optional[torch.FloatTensor] = None) -> torch.Tensor:
+        if feature_map is not None:
+            raise ValueError("feature_map has been deprecated as an input. Please pass in num_patches instead")
         # The box center is biased to its position on the feature grid
-        box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
+        box_coordinates = self.normalize_grid_corner_coordinates(num_patches)
         box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
 
         # Unnormalize xy
         box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
 
         # The box size is biased to the patch size
-        box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
+        box_size = torch.full_like(box_coord_bias, 1.0 / num_patches)
         box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
 
         # Compute box bias
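
[Review note] `@lru_cache(maxsize=2)` on a method keys the cache on `(self, num_patches)`; that works here because `num_patches` is a hashable int, and the two most recently used grid sizes stay cached. One caveat: cached entries keep a reference to `self` alive for the cache's lifetime. A toy sketch of the semantics (hypothetical class, not PR code):

```python
from functools import lru_cache

class Toy:
    @lru_cache(maxsize=2)
    def compute(self, n: int) -> int:
        print(f"miss: computing for n={n}")
        return n * n

t = Toy()
t.compute(16)                    # miss: computing for n=16
t.compute(16)                    # cache hit, nothing printed
t.compute(24)                    # miss: computing for n=24
print(Toy.compute.cache_info())  # CacheInfo(hits=1, misses=2, maxsize=2, currsize=2)
```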
@@ -1351,7 +1349,8 @@ def box_predictor(
         pred_boxes = self.box_head(image_feats)
 
         # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
-        pred_boxes += self.compute_box_bias(feature_map)
+        box_bias = self.box_bias.to(feature_map.device)
+        pred_boxes += box_bias
         pred_boxes = self.sigmoid(pred_boxes)
         return pred_boxes
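
[Review note] The retained `feature_map` keyword exists only so old-style keyword callers fail loudly instead of silently misbehaving. A runnable sketch against the standard public OWL-ViT checkpoint (illustrative):

```python
import torch
from transformers import OwlViTForObjectDetection

model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

try:
    model.compute_box_bias(num_patches=16, feature_map=torch.ones(1, 16, 16, 768))
except ValueError as err:
    print(err)  # feature_map has been deprecated as an input. Please pass in num_patches instead
```

Worth noting: a purely positional old-style call (`compute_box_bias(feature_map)`) would bind the tensor to `num_patches` and fail later with a less helpful error; only keyword callers get this message.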
