Skip to content

Commit

Permalink
Move bias initializations from private methods to constructors (#351)
Browse files Browse the repository at this point in the history
* Move bias initializations from private methods to constructors

* Cleanup header

* Fix docstrings
  • Loading branch information
zhiqwang authored Mar 9, 2022
1 parent 225759b commit 0345ed6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 34 deletions.
3 changes: 2 additions & 1 deletion yolort/models/anchor_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
# Copyright (c) 2020, yolort team. All rights reserved.

from typing import Tuple, List

import torch
Expand Down
54 changes: 28 additions & 26 deletions yolort/models/box_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2020, yolort team. All rights reserved.

import math
from typing import Tuple, List, Dict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
Expand All @@ -11,13 +12,18 @@


class YOLOHead(nn.Module):
def __init__(
self,
in_channels: List[int],
num_anchors: int,
strides: List[int],
num_classes: int,
):
"""
A regression and classification head for use in YOLO.
Args:
in_channels (List[int]): number of channels of each input feature map, one entry per head block
num_anchors (int): number of anchors to be predicted
strides (List[int]): stride of each input feature map relative to the image, one entry per head block
num_classes (int): number of classes to be predicted
"""

def __init__(self, in_channels: List[int], num_anchors: int, strides: List[int], num_classes: int):

super().__init__()
if not isinstance(in_channels, list):
in_channels = [in_channels] * len(strides)
Expand All @@ -26,25 +32,21 @@ def __init__(
self.num_outputs = num_classes + 5 # number of outputs per anchor
self.strides = strides

self.head = nn.ModuleList(
head_blocks = nn.ModuleList(
nn.Conv2d(ch, self.num_outputs * self.num_anchors, 1) for ch in in_channels
) # output conv

self._initialize_biases() # Init weights, biases
)

def _initialize_biases(self, cf=None):
"""
Initialize biases into YOLOHead, cf is class frequency
Check section 3.3 in <https://arxiv.org/abs/1708.02002>
"""
for mi, s in zip(self.head, self.strides):
# Initialize biases into head blocks
for mi, s in zip(head_blocks, self.strides):
b = mi.bias.view(self.num_anchors, -1) # conv.bias(255) to (3,85)
# obj (8 objects per 640 image)
b.data[:, 4] += math.log(8 / (640 / s) ** 2)
# classes
b.data[:, 5:] += torch.log(cf / cf.sum()) if cf else math.log(0.6 / (self.num_classes - 0.99))
b.data[:, 5:] += math.log(0.6 / (self.num_classes - 0.999999))
mi.bias = nn.Parameter(b.view(-1), requires_grad=True)

self.head = head_blocks

def get_result_from_head(self, features: Tensor, idx: int) -> Tensor:
"""
This is equivalent to self.head[idx](features),
Expand Down Expand Up @@ -358,6 +360,12 @@ def _decode_pred_logits(pred_logits: Tensor):
class PostProcess(nn.Module):
"""
Performs Non-Maximum Suppression (NMS) on inference results
Args:
strides (List[int]): Strides of the AnchorGenerator.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""

def __init__(
Expand All @@ -367,13 +375,7 @@ def __init__(
nms_thresh: float,
detections_per_img: int,
) -> None:
"""
Args:
strides (List[int]): Strides of the AnchorGenerator.
score_thresh (float): Score threshold used for postprocessing the detections.
nms_thresh (float): NMS threshold used for postprocessing the detections.
detections_per_img (int): Number of best detections to keep after NMS.
"""

super().__init__()
self.strides = strides
self.score_thresh = score_thresh
Expand Down
10 changes: 3 additions & 7 deletions yolort/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,7 @@ def get_image_from_url(url: str, flags: int = 1) -> np.ndarray:
return image


def read_image_to_tensor(
image: np.ndarray,
is_half: bool = False,
) -> Tensor:
def read_image_to_tensor(image: np.ndarray, is_half: bool = False) -> Tensor:
"""
Parse an image to Tensor.
Expand All @@ -122,9 +119,8 @@ def read_image_to_tensor(
image = np.ascontiguousarray(image, dtype=np.float32) # uint8 to float32
image = np.transpose(image / 255.0, [2, 0, 1])

image = torch.from_numpy(image)
image = image.half() if is_half else image.float()
return image
_dtype = torch.float16 if is_half else torch.float32
return torch.from_numpy(image).to(dtype=_dtype)


def load_names(category_path):
Expand Down

0 comments on commit 0345ed6

Please sign in to comment.