Skip to content

Commit

Permalink
use new ops in apache/mxnet#16215
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerryzcn committed Oct 3, 2019
1 parent 145c061 commit a546c46
Show file tree
Hide file tree
Showing 13 changed files with 149 additions and 102 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ The following commands install the stable version of GluonCV and MXNet:
```bash
pip install gluoncv --upgrade
pip install mxnet-mkl --upgrade
# if cuda 9.2 is installed
pip install mxnet-cu92mkl --upgrade
# if cuda 10.1 is installed
pip install mxnet-cu101mkl --upgrade
```

**The latest stable version of GluonCV is 0.4 and depends on mxnet >= 1.4.0**
Expand All @@ -66,8 +66,8 @@ You may get access to latest features and bug fixes with the following commands
```bash
pip install gluoncv --pre --upgrade
pip install mxnet-mkl --pre --upgrade
# if cuda 9.2 is installed
pip install mxnet-cu92mkl --pre --upgrade
# if cuda 10.1 is installed
pip install mxnet-cu101mkl --pre --upgrade
```

There are multiple versions of MXNet pre-built package available. Please refer to [mxnet packages](https://gluon-crash-course.mxnet.io/mxnet_packages.html) if you need more details about MXNet versions.
Expand Down
34 changes: 21 additions & 13 deletions gluoncv/model_zoo/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self, features, top_features, classes, box_features=None,
clip=clip, nms_thresh=rpn_nms_thresh, train_pre_nms=rpn_train_pre_nms,
train_post_nms=rpn_train_post_nms, test_pre_nms=rpn_test_pre_nms,
test_post_nms=rpn_test_post_nms, min_size=rpn_min_size,
multi_level=self.num_stages > 1)
multi_level=self.num_stages > 1, per_level_nms=True)
self.sampler = RCNNTargetSampler(num_image=self._batch_size,
num_proposal=rpn_train_post_nms, num_sample=num_sample,
pos_iou_thresh=pos_iou_thresh, pos_ratio=pos_ratio,
Expand Down Expand Up @@ -292,16 +292,20 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
# rpn_rois = F.take(rpn_rois, roi_level_sorted_args, axis=0)
pooled_roi_feats = []
for i, l in enumerate(range(self._min_stage, max_stage + 1)):
# Pool features with all rois first, and then set invalid pooled features to zero,
# at last ele-wise add together to aggregate all features.
if roi_mode == 'pool':
# Pool features with all rois first, and then set invalid pooled features to zero,
# at last ele-wise add together to aggregate all features.
pooled_feature = F.ROIPooling(features[i], rpn_rois, roi_size, 1. / strides[i])
pooled_feature = F.where(roi_level == l, pooled_feature,
F.zeros_like(pooled_feature))
elif roi_mode == 'align':
pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size,
masked_rpn_rois = F.where(roi_level == l, rpn_rois, F.ones_like(rpn_rois) * -1.)
pooled_feature = F.contrib.ROIAlign(features[i], masked_rpn_rois, roi_size,
1. / strides[i], sample_ratio=2)
# pooled_feature = F.where(roi_level == l, pooled_feature,
# F.zeros_like(pooled_feature))
else:
raise ValueError("Invalid roi mode: {}".format(roi_mode))
pooled_feature = F.where(roi_level == l, pooled_feature, F.zeros_like(pooled_feature))
pooled_roi_feats.append(pooled_feature)
# Ele-wise add to aggregate all pooled features
pooled_roi_feats = F.ElementWiseSum(*pooled_roi_feats)
Expand All @@ -312,7 +316,7 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode=
return pooled_roi_feats

# pylint: disable=arguments-differ
def hybrid_forward(self, F, x, gt_box=None):
def hybrid_forward(self, F, x, gt_box=None, gt_label=None):
"""Forward Faster-RCNN network.
The behavior during training and inference is different.
Expand All @@ -322,7 +326,9 @@ def hybrid_forward(self, F, x, gt_box=None):
x : mxnet.nd.NDArray or mxnet.symbol
The network input tensor.
gt_box : type, only required during training
The ground-truth bbox tensor with shape (1, N, 4).
The ground-truth bbox tensor with shape (B, N, 4).
gt_label : type, only required during training
The ground-truth label tensor with shape (B, 1, 4).
Returns
-------
Expand Down Expand Up @@ -393,11 +399,13 @@ def _split(x, axis, num_outputs, squeeze_axis):

# no need to convert bounding boxes in training, just return
if autograd.is_training():
cls_targets, box_targets, box_masks = self.target_generator(rpn_box, samples, matches,
gt_label, gt_box)
if self._additional_output:
return (cls_pred, box_pred, rpn_box, samples, matches,
raw_rpn_score, raw_rpn_box, anchors, top_feat)
return (cls_pred, box_pred, rpn_box, samples, matches,
raw_rpn_score, raw_rpn_box, anchors)
return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
anchors, cls_targets, box_targets, box_masks, top_feat)
return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box,
anchors, cls_targets, box_targets, box_masks)

# cls_ids (B, N, C), scores (B, N, C)
cls_ids, scores = self.cls_decoder(F.softmax(cls_pred, axis=-1))
Expand All @@ -419,7 +427,7 @@ def _split(x, axis, num_outputs, squeeze_axis):
results = []
for rpn_box, cls_id, score, box_pred in zip(rpn_boxes, cls_ids, scores, box_preds):
# box_pred (C, N, 4) rpn_box (1, N, 4) -> bbox (C, N, 4)
bbox = self.box_decoder(box_pred, self.box_to_center(rpn_box))
bbox = self.box_decoder(box_pred, rpn_box)
# res (C, N, 6)
res = F.concat(*[cls_id, score, bbox], dim=-1)
if self.force_nms:
Expand Down Expand Up @@ -683,7 +691,7 @@ def faster_rcnn_fpn_bn_resnet50_v1b_coco(pretrained=False, pretrained_base=True,
top_features = None
# 1 Conv 1 FC layer before RCNN cls and reg
box_features = nn.HybridSequential()
box_features.add(nn.Conv2D(256, 3, padding=1),
box_features.add(nn.Conv2D(256, 3, padding=1, use_bias=False),
SyncBatchNorm(**gluon_norm_kwargs),
nn.Activation('relu'),
nn.Dense(1024, weight_initializer=mx.init.Normal(0.01)),
Expand Down
19 changes: 11 additions & 8 deletions gluoncv/model_zoo/faster_rcnn/rcnn_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
Parameters
----------
rois: (B, self._num_input, 4) encoded in (x1, y1, x2, y2).
scores: (B, self._num_input, 1), value range [0, 1] with ignore value -1.
rois: (B, self._num_proposal, 4) encoded in (x1, y1, x2, y2).
scores: (B, self._num_proposal, 1), value range [0, 1] with ignore value -1.
gt_boxes: (B, M, 4) encoded in (x1, y1, x2, y2), invalid box should have area of 0.
Returns
Expand All @@ -65,7 +65,7 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
roi = F.squeeze(F.slice_axis(rois, axis=0, begin=i, end=i + 1), axis=0)
score = F.squeeze(F.slice_axis(scores, axis=0, begin=i, end=i + 1), axis=0)
gt_box = F.squeeze(F.slice_axis(gt_boxes, axis=0, begin=i, end=i + 1), axis=0)
gt_score = F.ones_like(F.sum(gt_box, axis=-1, keepdims=True))
gt_score = F.sign(F.sum(gt_box, axis=-1, keepdims=True) + 1)

# concat rpn roi with ground truth. mix gt with generated boxes.
all_roi = F.concat(roi, gt_box, dim=0)
Expand Down Expand Up @@ -126,9 +126,13 @@ def hybrid_forward(self, F, rois, scores, gt_boxes):
samples = F.concat(topk_samples, bottomk_samples, dim=0)
matches = F.concat(topk_matches, bottomk_matches, dim=0)

new_rois.append(all_roi.take(indices))
new_samples.append(samples)
new_matches.append(matches)
sampled_rois = all_roi.take(indices)
x1, y1, x2, y2 = F.split(sampled_rois, axis=-1, num_outputs=4, squeeze_axis=True)
rois_area = (x2 - x1) * (y2 - y1)
ind = F.argsort(rois_area)
new_rois.append(sampled_rois.take(ind))
new_samples.append(samples.take(ind))
new_matches.append(matches.take(ind))
# stack all samples together
new_rois = F.stack(*new_rois, axis=0)
new_samples = F.stack(*new_samples, axis=0)
Expand Down Expand Up @@ -179,6 +183,5 @@ def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box):
# cls_target (B, N)
cls_target = self._cls_encoder(samples, matches, gt_label)
# box_target, box_weight (C, B, N, 4)
box_target, box_mask = self._box_encoder(
samples, matches, roi, gt_label, gt_box)
box_target, box_mask = self._box_encoder(samples, matches, roi, gt_label, gt_box)
return cls_target, box_target, box_mask
14 changes: 9 additions & 5 deletions gluoncv/model_zoo/mask_rcnn/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __init__(self, features, top_features, classes, mask_channels=256, rcnn_max_
self.mask_target = MaskTargetGenerator(
self._batch_size, self._num_sample, self.num_class, self._target_roi_size)

def hybrid_forward(self, F, x, gt_box=None):
def hybrid_forward(self, F, x, gt_box=None, gt_label=None):
"""Forward Mask RCNN network.
The behavior during training and inference is different.
Expand All @@ -212,6 +212,10 @@ def hybrid_forward(self, F, x, gt_box=None):
The network input tensor.
gt_box : type, only required during training
The ground-truth bbox tensor with shape (1, N, 4).
gt_label : type, only required during training
The ground-truth label tensor with shape (B, 1, 4).
gt_label : type, only required during training
The ground-truth mask tensor.
Returns
-------
Expand All @@ -221,12 +225,12 @@ def hybrid_forward(self, F, x, gt_box=None):
"""
if autograd.is_training():
cls_pred, box_pred, rpn_box, samples, matches, \
raw_rpn_score, raw_rpn_box, anchors, top_feat = \
super(MaskRCNN, self).hybrid_forward(F, x, gt_box)
cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box, anchors, \
cls_targets, box_targets, box_masks, top_feat = \
super(MaskRCNN, self).hybrid_forward(F, x, gt_box, gt_label)
mask_pred = self.mask(top_feat)
return cls_pred, box_pred, mask_pred, rpn_box, samples, matches, \
raw_rpn_score, raw_rpn_box, anchors
raw_rpn_score, raw_rpn_box, anchors, cls_targets, box_targets, box_masks
else:
batch_size = 1
ids, scores, boxes, feat = \
Expand Down
4 changes: 1 addition & 3 deletions gluoncv/model_zoo/rcnn/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from ...nn.bbox import BBoxCornerToCenter
from ...nn.coder import NormalizedBoxCenterDecoder, MultiPerClassDecoder


Expand Down Expand Up @@ -101,8 +100,7 @@ def __init__(self, features, top_features, classes, box_features,
self.box_predictor = nn.Dense(
self.num_class * 4, weight_initializer=mx.init.Normal(0.001))
self.cls_decoder = MultiPerClassDecoder(num_class=self.num_class + 1)
self.box_to_center = BBoxCornerToCenter()
self.box_decoder = NormalizedBoxCenterDecoder(clip=clip)
self.box_decoder = NormalizedBoxCenterDecoder(clip=clip, convert_anchor=True)

def collect_train_params(self, select=None):
"""Collect trainable params.
Expand Down
2 changes: 1 addition & 1 deletion gluoncv/model_zoo/rpn/anchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def hybrid_forward(self, F, x, anchors):
- **out**: output anchor with (1, N, 4) shape. N is the number of anchors.
"""
a = F.slice_like(anchors, x * 0, axes=(2, 3))
a = F.slice_like(anchors, x, axes=(2, 3))
return a.reshape((1, -1, 4))


Expand Down
6 changes: 3 additions & 3 deletions gluoncv/model_zoo/rpn/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ class RPNProposal(gluon.HybridBlock):
def __init__(self, clip, min_size, stds):
super(RPNProposal, self).__init__()
self._box_to_center = BBoxCornerToCenter()
self._box_decoder = NormalizedBoxCenterDecoder(stds=stds, clip=clip)
self._box_decoder = NormalizedBoxCenterDecoder(stds=stds, clip=clip, convert_anchor=True)
self._clipper = BBoxClipToImage()
# self._compute_area = BBoxArea()
self._min_size = min_size

# pylint: disable=arguments-differ
def hybrid_forward(self, F, anchor, score, bbox_pred, img):
"""
Generate proposals. Limit to batch-size=1 in current implementation.
Generate proposals.
"""
with autograd.pause():
# restore bounding boxes
roi = self._box_decoder(bbox_pred, self._box_to_center(anchor))
roi = self._box_decoder(bbox_pred, anchor)

# clip rois to image's boundary
# roi = F.Custom(roi, img, op_type='bbox_clip_to_image')
Expand Down
33 changes: 25 additions & 8 deletions gluoncv/model_zoo/rpn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class RPN(gluon.HybridBlock):
The aspect ratios of anchor boxes. We expect it to be a list or tuple.
alloc_size : tuple of int
Allocate size for the anchor boxes as (H, W).
Usually we generate enough anchors for large feature map, e.g. 128x128.
Usually we generate enough anchors flt is or large feature map, e.g. 128x128.
Later in inference we can have variable input sizes,
at which time we can crop corresponding anchors from this large
anchor map so we can skip re-generating anchors for each input.
Expand All @@ -55,15 +55,21 @@ class RPN(gluon.HybridBlock):
Proposals whose size is smaller than ``min_size`` will be discarded.
multi_level : boolean
Whether to extract feature from multiple level. This is used in FPN.
multi_level : boolean, default is False.
Whether to use multiple feature maps for RPN. eg. FPN.
per_level_nms : boollean, default is False
Whether to apply nms on each level's rois instead of applying nms after aggregation.
"""

def __init__(self, channels, strides, base_size, scales, ratios, alloc_size,
clip, nms_thresh, train_pre_nms, train_post_nms,
test_pre_nms, test_post_nms, min_size, multi_level=False, **kwargs):
test_pre_nms, test_post_nms, min_size, multi_level=False, per_level_nms=False,
**kwargs):
super(RPN, self).__init__(**kwargs)
self._nms_thresh = nms_thresh
self._multi_level = multi_level
self._per_level_nms = per_level_nms
self._train_pre_nms = max(1, train_pre_nms)
self._train_post_nms = max(1, train_post_nms)
self._test_pre_nms = max(1, test_pre_nms)
Expand Down Expand Up @@ -131,8 +137,13 @@ def hybrid_forward(self, F, img, *x):
anchor = ag(feat)
rpn_score, rpn_box, raw_rpn_score, raw_rpn_box = \
self.rpn_head(feat)
rpn_pre = self.region_proposer(anchor, rpn_score,
rpn_box, img)
rpn_pre = self.region_proposer(anchor, rpn_score, rpn_box, img)
if self._per_level_nms:
with autograd.pause():
# Non-maximum suppression
rpn_pre = F.contrib.box_nms(rpn_pre, overlap_thresh=self._nms_thresh,
topk=pre_nms // len(x), coord_start=1,
score_index=0, id_index=-1)
anchors.append(anchor)
rpn_pre_nms_proposals.append(rpn_pre)
raw_rpn_scores.append(raw_rpn_score)
Expand All @@ -151,11 +162,17 @@ def hybrid_forward(self, F, img, *x):
rpn_pre_nms_proposals = self.region_proposer(
anchors, rpn_scores, rpn_boxes, img)

# Non-maximum suppression
with autograd.pause():
tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=self._nms_thresh,
topk=pre_nms, coord_start=1, score_index=0, id_index=-1,
force_suppress=True)
if self._per_level_nms and self._multi_level:
# sort by scores
tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=2.,
topk=pre_nms + 1, coord_start=1, score_index=0, id_index=-1,
force_suppress=True)
else:
# Non-maximum suppression
tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=self._nms_thresh,
topk=pre_nms, coord_start=1, score_index=0, id_index=-1,
force_suppress=True)

# slice post_nms number of boxes
result = F.slice_axis(tmp, axis=1, begin=0, end=post_nms)
Expand Down
12 changes: 6 additions & 6 deletions gluoncv/nn/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __call__(self, x):
# this is different that detectron.
width = xmax - xmin
height = ymax - ymin
x = xmin + width / 2
y = ymin + height / 2
x = xmin + width * 0.5
y = ymin + height * 0.5
if not self._split:
return np.concatenate((x, y, width, height), axis=self._axis)
else:
Expand Down Expand Up @@ -71,8 +71,8 @@ def hybrid_forward(self, F, x):
# this is different that detectron.
width = xmax - xmin
height = ymax - ymin
x = xmin + width / 2
y = ymin + height / 2
x = xmin + width * 0.5
y = ymin + height * 0.5
if not self._split:
return F.concat(x, y, width, height, dim=self._axis)
else:
Expand Down Expand Up @@ -104,8 +104,8 @@ def __init__(self, axis=-1, split=False):
def hybrid_forward(self, F, x):
"""Hybrid forward"""
x, y, w, h = F.split(x, axis=self._axis, num_outputs=4)
hw = w / 2
hh = h / 2
hw = w * 0.5
hh = h * 0.5
xmin = x - hw
ymin = y - hh
xmax = x + hw
Expand Down
Loading

0 comments on commit a546c46

Please sign in to comment.