From a6734349a325866c230186bd166bbe769ce16303 Mon Sep 17 00:00:00 2001 From: dkloving Date: Fri, 21 May 2021 08:38:34 +0300 Subject: [PATCH] AnchorGenerator will use dtype of input feature_maps AnchorGenerator.grid_anchors had fp32 hard-coded which could result in forward pass returning mismatched datatypes, for example (fp32, fp16, fp16). Fix for #107 --- yolort/models/anchor_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/yolort/models/anchor_utils.py b/yolort/models/anchor_utils.py index dcb6aa21..dc5351fb 100644 --- a/yolort/models/anchor_utils.py +++ b/yolort/models/anchor_utils.py @@ -58,6 +58,7 @@ def set_xy_weights( def grid_anchors( self, grid_sizes: List[List[int]], + dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu"), ) -> Tensor: @@ -67,8 +68,8 @@ def grid_anchors( grid_height, grid_width = size # For output anchor, compute [x_center, y_center, x_center, y_center] - shifts_x = torch.arange(0, grid_width, dtype=torch.float32, device=device) - shifts_y = torch.arange(0, grid_height, dtype=torch.float32, device=device) + shifts_x = torch.arange(0, grid_width, dtype=dtype, device=device) + shifts_y = torch.arange(0, grid_height, dtype=dtype, device=device) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shifts = torch.stack((shift_x, shift_y), dim=2) @@ -87,6 +88,6 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]: wh_weights = self.set_wh_weights(grid_sizes, dtype, device) xy_weights = self.set_xy_weights(grid_sizes, dtype, device) - anchors = self.grid_anchors(grid_sizes, device) + anchors = self.grid_anchors(grid_sizes, dtype, device) return anchors, wh_weights, xy_weights