AnchorGenerator will use dtype of input feature_maps (#111)
AnchorGenerator.grid_anchors had fp32 hard-coded, which could result in the forward pass returning mismatched datatypes, for example (fp32, fp16).
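The failure mode is easy to reproduce in isolation. The sketch below uses standalone tensors rather than the yolort classes: the coordinate shifts were previously created in hard-coded float32, while the wh/xy weights already followed the feature-map dtype, so a half-precision forward pass returned a mixed pair.

```python
import torch

# Illustrative sketch of the symptom, not the yolort code itself.
feature_map = torch.randn(1, 255, 20, 20, dtype=torch.float16)
dtype, device = feature_map.dtype, feature_map.device

shifts = torch.arange(0, 20, dtype=torch.float32, device=device)  # old behaviour: hard-coded fp32
weights = torch.ones(20, dtype=dtype, device=device)              # already follows the input dtype

print(shifts.dtype, weights.dtype)  # torch.float32 torch.float16 -> mismatch
```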
dkloving authored May 21, 2021
1 parent 218c428 commit b0af4a1
7 changes: 4 additions & 3 deletions yolort/models/anchor_utils.py
@@ -58,6 +58,7 @@ def set_xy_weights(
     def grid_anchors(
         self,
         grid_sizes: List[List[int]],
+        dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
     ) -> Tensor:

@@ -67,8 +68,8 @@ def grid_anchors(
             grid_height, grid_width = size
 
             # For output anchor, compute [x_center, y_center, x_center, y_center]
-            shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device)
-            shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device)
+            shifts_x = torch.arange(0, grid_width, dtype=dtype, device=device)
+            shifts_y = torch.arange(0, grid_height, dtype=dtype, device=device)
             shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
 
             shifts = torch.stack((shift_x, shift_y), dim=2)
@@ -87,6 +88,6 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
 
         wh_weights = self.set_wh_weights(grid_sizes, dtype, device)
         xy_weights = self.set_xy_weights(grid_sizes, dtype, device)
-        anchors = self.grid_anchors(grid_sizes, device)
+        anchors = self.grid_anchors(grid_sizes, dtype, device)
 
         return anchors, wh_weights, xy_weights
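After this change, the dtype is taken from the feature maps in forward and threaded into grid_anchors, so all three returned tensors share the input datatype. A minimal sketch of the fixed pattern, using a hypothetical grid_shifts helper (not part of the yolort API) that mirrors the arange/meshgrid/stack logic in the diff:

```python
import torch

def grid_shifts(grid_height, grid_width, dtype, device):
    # dtype is now a parameter instead of a hard-coded torch.float32
    shifts_x = torch.arange(0, grid_width, dtype=dtype, device=device)
    shifts_y = torch.arange(0, grid_height, dtype=dtype, device=device)
    # indexing="ij" is spelled out for newer PyTorch; the original call relies on the default
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
    return torch.stack((shift_x, shift_y), dim=2)

feature_map = torch.randn(1, 255, 20, 20, dtype=torch.float16)
shifts = grid_shifts(20, 20, feature_map.dtype, feature_map.device)
assert shifts.dtype == torch.float16  # anchors now match the input dtype
```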
