From a546c461fbcb07807b9971dd0b9aec3bb19d7455 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 3 Oct 2019 15:31:28 -0700 Subject: [PATCH] use new ops in https://github.com/apache/incubator-mxnet/pull/16215 --- README.md | 8 +-- gluoncv/model_zoo/faster_rcnn/faster_rcnn.py | 34 +++++++----- gluoncv/model_zoo/faster_rcnn/rcnn_target.py | 19 ++++--- gluoncv/model_zoo/mask_rcnn/mask_rcnn.py | 14 +++-- gluoncv/model_zoo/rcnn/rcnn.py | 4 +- gluoncv/model_zoo/rpn/anchor.py | 2 +- gluoncv/model_zoo/rpn/proposal.py | 6 +-- gluoncv/model_zoo/rpn/rpn.py | 33 +++++++++--- gluoncv/nn/bbox.py | 12 ++--- gluoncv/nn/coder.py | 53 ++++++++++-------- gluoncv/nn/feature.py | 4 +- .../faster_rcnn/train_faster_rcnn.py | 8 +-- scripts/instance/mask_rcnn/train_mask_rcnn.py | 54 +++++++++++-------- 13 files changed, 149 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index c2b8b473d2..2670fce8b2 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,8 @@ The following commands install the stable version of GluonCV and MXNet: ```bash pip install gluoncv --upgrade pip install mxnet-mkl --upgrade -# if cuda 9.2 is installed -pip install mxnet-cu92mkl --upgrade +# if cuda 10.1 is installed +pip install mxnet-cu101mkl --upgrade ``` **The latest stable version of GluonCV is 0.4 and depends on mxnet >= 1.4.0** @@ -66,8 +66,8 @@ You may get access to latest features and bug fixes with the following commands ```bash pip install gluoncv --pre --upgrade pip install mxnet-mkl --pre --upgrade -# if cuda 9.2 is installed -pip install mxnet-cu92mkl --pre --upgrade +# if cuda 10.1 is installed +pip install mxnet-cu101mkl --pre --upgrade ``` There are multiple versions of MXNet pre-built package available. Please refer to [mxnet packages](https://gluon-crash-course.mxnet.io/mxnet_packages.html) if you need more details about MXNet versions. 
diff --git a/gluoncv/model_zoo/faster_rcnn/faster_rcnn.py b/gluoncv/model_zoo/faster_rcnn/faster_rcnn.py index 4c6c5084fa..ebeb0f9e15 100644 --- a/gluoncv/model_zoo/faster_rcnn/faster_rcnn.py +++ b/gluoncv/model_zoo/faster_rcnn/faster_rcnn.py @@ -207,7 +207,7 @@ def __init__(self, features, top_features, classes, box_features=None, clip=clip, nms_thresh=rpn_nms_thresh, train_pre_nms=rpn_train_pre_nms, train_post_nms=rpn_train_post_nms, test_pre_nms=rpn_test_pre_nms, test_post_nms=rpn_test_post_nms, min_size=rpn_min_size, - multi_level=self.num_stages > 1) + multi_level=self.num_stages > 1, per_level_nms=True) self.sampler = RCNNTargetSampler(num_image=self._batch_size, num_proposal=rpn_train_post_nms, num_sample=num_sample, pos_iou_thresh=pos_iou_thresh, pos_ratio=pos_ratio, @@ -292,16 +292,20 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode= # rpn_rois = F.take(rpn_rois, roi_level_sorted_args, axis=0) pooled_roi_feats = [] for i, l in enumerate(range(self._min_stage, max_stage + 1)): - # Pool features with all rois first, and then set invalid pooled features to zero, - # at last ele-wise add together to aggregate all features. if roi_mode == 'pool': + # Pool features with all rois first, and then set invalid pooled features to zero, + # at last ele-wise add together to aggregate all features. pooled_feature = F.ROIPooling(features[i], rpn_rois, roi_size, 1. / strides[i]) + pooled_feature = F.where(roi_level == l, pooled_feature, + F.zeros_like(pooled_feature)) elif roi_mode == 'align': - pooled_feature = F.contrib.ROIAlign(features[i], rpn_rois, roi_size, + masked_rpn_rois = F.where(roi_level == l, rpn_rois, F.ones_like(rpn_rois) * -1.) + pooled_feature = F.contrib.ROIAlign(features[i], masked_rpn_rois, roi_size, 1. 
/ strides[i], sample_ratio=2) + # pooled_feature = F.where(roi_level == l, pooled_feature, + # F.zeros_like(pooled_feature)) else: raise ValueError("Invalid roi mode: {}".format(roi_mode)) - pooled_feature = F.where(roi_level == l, pooled_feature, F.zeros_like(pooled_feature)) pooled_roi_feats.append(pooled_feature) # Ele-wise add to aggregate all pooled features pooled_roi_feats = F.ElementWiseSum(*pooled_roi_feats) @@ -312,7 +316,7 @@ def _pyramid_roi_feats(self, F, features, rpn_rois, roi_size, strides, roi_mode= return pooled_roi_feats # pylint: disable=arguments-differ - def hybrid_forward(self, F, x, gt_box=None): + def hybrid_forward(self, F, x, gt_box=None, gt_label=None): """Forward Faster-RCNN network. The behavior during training and inference is different. @@ -322,7 +326,9 @@ def hybrid_forward(self, F, x, gt_box=None): x : mxnet.nd.NDArray or mxnet.symbol The network input tensor. gt_box : type, only required during training - The ground-truth bbox tensor with shape (1, N, 4). + The ground-truth bbox tensor with shape (B, N, 4). + gt_label : type, only required during training + The ground-truth label tensor with shape (B, N, 1). 
Returns ------- @@ -393,11 +399,13 @@ def _split(x, axis, num_outputs, squeeze_axis): # no need to convert bounding boxes in training, just return if autograd.is_training(): + cls_targets, box_targets, box_masks = self.target_generator(rpn_box, samples, matches, + gt_label, gt_box) if self._additional_output: - return (cls_pred, box_pred, rpn_box, samples, matches, - raw_rpn_score, raw_rpn_box, anchors, top_feat) - return (cls_pred, box_pred, rpn_box, samples, matches, - raw_rpn_score, raw_rpn_box, anchors) + return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box, + anchors, cls_targets, box_targets, box_masks, top_feat) + return (cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box, + anchors, cls_targets, box_targets, box_masks) # cls_ids (B, N, C), scores (B, N, C) cls_ids, scores = self.cls_decoder(F.softmax(cls_pred, axis=-1)) @@ -419,7 +427,7 @@ def _split(x, axis, num_outputs, squeeze_axis): results = [] for rpn_box, cls_id, score, box_pred in zip(rpn_boxes, cls_ids, scores, box_preds): # box_pred (C, N, 4) rpn_box (1, N, 4) -> bbox (C, N, 4) - bbox = self.box_decoder(box_pred, self.box_to_center(rpn_box)) + bbox = self.box_decoder(box_pred, rpn_box) # res (C, N, 6) res = F.concat(*[cls_id, score, bbox], dim=-1) if self.force_nms: @@ -683,7 +691,7 @@ def faster_rcnn_fpn_bn_resnet50_v1b_coco(pretrained=False, pretrained_base=True, top_features = None # 1 Conv 1 FC layer before RCNN cls and reg box_features = nn.HybridSequential() - box_features.add(nn.Conv2D(256, 3, padding=1), + box_features.add(nn.Conv2D(256, 3, padding=1, use_bias=False), SyncBatchNorm(**gluon_norm_kwargs), nn.Activation('relu'), nn.Dense(1024, weight_initializer=mx.init.Normal(0.01)), diff --git a/gluoncv/model_zoo/faster_rcnn/rcnn_target.py b/gluoncv/model_zoo/faster_rcnn/rcnn_target.py index db727f64e3..56cedddfb1 100644 --- a/gluoncv/model_zoo/faster_rcnn/rcnn_target.py +++ b/gluoncv/model_zoo/faster_rcnn/rcnn_target.py @@ -45,8 +45,8 @@ 
def hybrid_forward(self, F, rois, scores, gt_boxes): Parameters ---------- - rois: (B, self._num_input, 4) encoded in (x1, y1, x2, y2). - scores: (B, self._num_input, 1), value range [0, 1] with ignore value -1. + rois: (B, self._num_proposal, 4) encoded in (x1, y1, x2, y2). + scores: (B, self._num_proposal, 1), value range [0, 1] with ignore value -1. gt_boxes: (B, M, 4) encoded in (x1, y1, x2, y2), invalid box should have area of 0. Returns @@ -65,7 +65,7 @@ def hybrid_forward(self, F, rois, scores, gt_boxes): roi = F.squeeze(F.slice_axis(rois, axis=0, begin=i, end=i + 1), axis=0) score = F.squeeze(F.slice_axis(scores, axis=0, begin=i, end=i + 1), axis=0) gt_box = F.squeeze(F.slice_axis(gt_boxes, axis=0, begin=i, end=i + 1), axis=0) - gt_score = F.ones_like(F.sum(gt_box, axis=-1, keepdims=True)) + gt_score = F.sign(F.sum(gt_box, axis=-1, keepdims=True) + 1) # concat rpn roi with ground truth. mix gt with generated boxes. all_roi = F.concat(roi, gt_box, dim=0) @@ -126,9 +126,13 @@ def hybrid_forward(self, F, rois, scores, gt_boxes): samples = F.concat(topk_samples, bottomk_samples, dim=0) matches = F.concat(topk_matches, bottomk_matches, dim=0) - new_rois.append(all_roi.take(indices)) - new_samples.append(samples) - new_matches.append(matches) + sampled_rois = all_roi.take(indices) + x1, y1, x2, y2 = F.split(sampled_rois, axis=-1, num_outputs=4, squeeze_axis=True) + rois_area = (x2 - x1) * (y2 - y1) + ind = F.argsort(rois_area) + new_rois.append(sampled_rois.take(ind)) + new_samples.append(samples.take(ind)) + new_matches.append(matches.take(ind)) # stack all samples together new_rois = F.stack(*new_rois, axis=0) new_samples = F.stack(*new_samples, axis=0) @@ -179,6 +183,5 @@ def hybrid_forward(self, F, roi, samples, matches, gt_label, gt_box): # cls_target (B, N) cls_target = self._cls_encoder(samples, matches, gt_label) # box_target, box_weight (C, B, N, 4) - box_target, box_mask = self._box_encoder( - samples, matches, roi, gt_label, gt_box) + box_target, 
box_mask = self._box_encoder(samples, matches, roi, gt_label, gt_box) return cls_target, box_target, box_mask diff --git a/gluoncv/model_zoo/mask_rcnn/mask_rcnn.py b/gluoncv/model_zoo/mask_rcnn/mask_rcnn.py index 760063e069..e7a1d41bb0 100644 --- a/gluoncv/model_zoo/mask_rcnn/mask_rcnn.py +++ b/gluoncv/model_zoo/mask_rcnn/mask_rcnn.py @@ -201,7 +201,7 @@ def __init__(self, features, top_features, classes, mask_channels=256, rcnn_max_ self.mask_target = MaskTargetGenerator( self._batch_size, self._num_sample, self.num_class, self._target_roi_size) - def hybrid_forward(self, F, x, gt_box=None): + def hybrid_forward(self, F, x, gt_box=None, gt_label=None): """Forward Mask RCNN network. The behavior during training and inference is different. @@ -212,6 +212,10 @@ def hybrid_forward(self, F, x, gt_box=None): The network input tensor. gt_box : type, only required during training The ground-truth bbox tensor with shape (1, N, 4). + gt_label : type, only required during training + The ground-truth label tensor with shape (B, 1, 4). + gt_label : type, only required during training + The ground-truth mask tensor. 
Returns ------- @@ -221,12 +225,12 @@ def hybrid_forward(self, F, x, gt_box=None): """ if autograd.is_training(): - cls_pred, box_pred, rpn_box, samples, matches, \ - raw_rpn_score, raw_rpn_box, anchors, top_feat = \ - super(MaskRCNN, self).hybrid_forward(F, x, gt_box) + cls_pred, box_pred, rpn_box, samples, matches, raw_rpn_score, raw_rpn_box, anchors, \ + cls_targets, box_targets, box_masks, top_feat = \ + super(MaskRCNN, self).hybrid_forward(F, x, gt_box, gt_label) mask_pred = self.mask(top_feat) return cls_pred, box_pred, mask_pred, rpn_box, samples, matches, \ - raw_rpn_score, raw_rpn_box, anchors + raw_rpn_score, raw_rpn_box, anchors, cls_targets, box_targets, box_masks else: batch_size = 1 ids, scores, boxes, feat = \ diff --git a/gluoncv/model_zoo/rcnn/rcnn.py b/gluoncv/model_zoo/rcnn/rcnn.py index 0593b337ae..9652239f7d 100644 --- a/gluoncv/model_zoo/rcnn/rcnn.py +++ b/gluoncv/model_zoo/rcnn/rcnn.py @@ -5,7 +5,6 @@ import mxnet as mx from mxnet import gluon from mxnet.gluon import nn -from ...nn.bbox import BBoxCornerToCenter from ...nn.coder import NormalizedBoxCenterDecoder, MultiPerClassDecoder @@ -101,8 +100,7 @@ def __init__(self, features, top_features, classes, box_features, self.box_predictor = nn.Dense( self.num_class * 4, weight_initializer=mx.init.Normal(0.001)) self.cls_decoder = MultiPerClassDecoder(num_class=self.num_class + 1) - self.box_to_center = BBoxCornerToCenter() - self.box_decoder = NormalizedBoxCenterDecoder(clip=clip) + self.box_decoder = NormalizedBoxCenterDecoder(clip=clip, convert_anchor=True) def collect_train_params(self, select=None): """Collect trainable params. diff --git a/gluoncv/model_zoo/rpn/anchor.py b/gluoncv/model_zoo/rpn/anchor.py index 2ad8e1acbb..5bf2a9732a 100644 --- a/gluoncv/model_zoo/rpn/anchor.py +++ b/gluoncv/model_zoo/rpn/anchor.py @@ -82,7 +82,7 @@ def hybrid_forward(self, F, x, anchors): - **out**: output anchor with (1, N, 4) shape. N is the number of anchors. 
""" - a = F.slice_like(anchors, x * 0, axes=(2, 3)) + a = F.slice_like(anchors, x, axes=(2, 3)) return a.reshape((1, -1, 4)) diff --git a/gluoncv/model_zoo/rpn/proposal.py b/gluoncv/model_zoo/rpn/proposal.py index 3b62be23d5..5bd697a6bb 100644 --- a/gluoncv/model_zoo/rpn/proposal.py +++ b/gluoncv/model_zoo/rpn/proposal.py @@ -30,7 +30,7 @@ class RPNProposal(gluon.HybridBlock): def __init__(self, clip, min_size, stds): super(RPNProposal, self).__init__() self._box_to_center = BBoxCornerToCenter() - self._box_decoder = NormalizedBoxCenterDecoder(stds=stds, clip=clip) + self._box_decoder = NormalizedBoxCenterDecoder(stds=stds, clip=clip, convert_anchor=True) self._clipper = BBoxClipToImage() # self._compute_area = BBoxArea() self._min_size = min_size @@ -38,11 +38,11 @@ def __init__(self, clip, min_size, stds): # pylint: disable=arguments-differ def hybrid_forward(self, F, anchor, score, bbox_pred, img): """ - Generate proposals. Limit to batch-size=1 in current implementation. + Generate proposals. """ with autograd.pause(): # restore bounding boxes - roi = self._box_decoder(bbox_pred, self._box_to_center(anchor)) + roi = self._box_decoder(bbox_pred, anchor) # clip rois to image's boundary # roi = F.Custom(roi, img, op_type='bbox_clip_to_image') diff --git a/gluoncv/model_zoo/rpn/rpn.py b/gluoncv/model_zoo/rpn/rpn.py index 9209af21d8..37cd77050e 100644 --- a/gluoncv/model_zoo/rpn/rpn.py +++ b/gluoncv/model_zoo/rpn/rpn.py @@ -35,7 +35,7 @@ class RPN(gluon.HybridBlock): The aspect ratios of anchor boxes. We expect it to be a list or tuple. alloc_size : tuple of int Allocate size for the anchor boxes as (H, W). - Usually we generate enough anchors for large feature map, e.g. 128x128. + Usually we generate enough anchors flt is or large feature map, e.g. 128x128. Later in inference we can have variable input sizes, at which time we can crop corresponding anchors from this large anchor map so we can skip re-generating anchors for each input. 
@@ -55,15 +55,21 @@ class RPN(gluon.HybridBlock): Proposals whose size is smaller than ``min_size`` will be discarded. multi_level : boolean Whether to extract feature from multiple level. This is used in FPN. + multi_level : boolean, default is False. + Whether to use multiple feature maps for RPN. e.g. FPN. + per_level_nms : boolean, default is False + Whether to apply nms on each level's rois instead of applying nms after aggregation. """ def __init__(self, channels, strides, base_size, scales, ratios, alloc_size, clip, nms_thresh, train_pre_nms, train_post_nms, - test_pre_nms, test_post_nms, min_size, multi_level=False, **kwargs): + test_pre_nms, test_post_nms, min_size, multi_level=False, per_level_nms=False, + **kwargs): super(RPN, self).__init__(**kwargs) self._nms_thresh = nms_thresh self._multi_level = multi_level + self._per_level_nms = per_level_nms self._train_pre_nms = max(1, train_pre_nms) self._train_post_nms = max(1, train_post_nms) self._test_pre_nms = max(1, test_pre_nms) @@ -131,8 +137,13 @@ def hybrid_forward(self, F, img, *x): anchor = ag(feat) rpn_score, rpn_box, raw_rpn_score, raw_rpn_box = \ self.rpn_head(feat) - rpn_pre = self.region_proposer(anchor, rpn_score, - rpn_box, img) + rpn_pre = self.region_proposer(anchor, rpn_score, rpn_box, img) + if self._per_level_nms: + with autograd.pause(): + # Non-maximum suppression + rpn_pre = F.contrib.box_nms(rpn_pre, overlap_thresh=self._nms_thresh, + topk=pre_nms // len(x), coord_start=1, + score_index=0, id_index=-1) anchors.append(anchor) rpn_pre_nms_proposals.append(rpn_pre) raw_rpn_scores.append(raw_rpn_score) @@ -151,11 +162,17 @@ def hybrid_forward(self, F, img, *x): rpn_pre_nms_proposals = self.region_proposer( anchors, rpn_scores, rpn_boxes, img) - # Non-maximum suppression with autograd.pause(): - tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=self._nms_thresh, - topk=pre_nms, coord_start=1, score_index=0, id_index=-1, - force_suppress=True) + if self._per_level_nms and 
self._multi_level: + # sort by scores + tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=2., + topk=pre_nms + 1, coord_start=1, score_index=0, id_index=-1, + force_suppress=True) + else: + # Non-maximum suppression + tmp = F.contrib.box_nms(rpn_pre_nms_proposals, overlap_thresh=self._nms_thresh, + topk=pre_nms, coord_start=1, score_index=0, id_index=-1, + force_suppress=True) # slice post_nms number of boxes result = F.slice_axis(tmp, axis=1, begin=0, end=post_nms) diff --git a/gluoncv/nn/bbox.py b/gluoncv/nn/bbox.py index a5a7ec79dc..d8134bf5c7 100644 --- a/gluoncv/nn/bbox.py +++ b/gluoncv/nn/bbox.py @@ -34,8 +34,8 @@ def __call__(self, x): # this is different that detectron. width = xmax - xmin height = ymax - ymin - x = xmin + width / 2 - y = ymin + height / 2 + x = xmin + width * 0.5 + y = ymin + height * 0.5 if not self._split: return np.concatenate((x, y, width, height), axis=self._axis) else: @@ -71,8 +71,8 @@ def hybrid_forward(self, F, x): # this is different that detectron. 
width = xmax - xmin height = ymax - ymin - x = xmin + width / 2 - y = ymin + height / 2 + x = xmin + width * 0.5 + y = ymin + height * 0.5 if not self._split: return F.concat(x, y, width, height, dim=self._axis) else: @@ -104,8 +104,8 @@ def __init__(self, axis=-1, split=False): def hybrid_forward(self, F, x): """Hybrid forward""" x, y, w, h = F.split(x, axis=self._axis, num_outputs=4) - hw = w / 2 - hh = h / 2 + hw = w * 0.5 + hh = h * 0.5 xmin = x - hw ymin = y - hh xmax = x + hw diff --git a/gluoncv/nn/coder.py b/gluoncv/nn/coder.py index ec9ef31889..daa85643c1 100644 --- a/gluoncv/nn/coder.py +++ b/gluoncv/nn/coder.py @@ -8,6 +8,7 @@ import numpy as np from mxnet import gluon +from mxnet import nd from .bbox import BBoxCornerToCenter, NumPyBBoxCornerToCenter @@ -95,14 +96,16 @@ class NormalizedBoxCenterEncoder(gluon.HybridBlock): """ - def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.)): - super(NormalizedBoxCenterEncoder, self).__init__() + def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.), **kwargs): + super(NormalizedBoxCenterEncoder, self).__init__(**kwargs) assert len(stds) == 4, "Box Encoder requires 4 std values." - self._stds = stds + assert len(means) == 4, "Box Encoder requires 4 mean values." 
self._means = means + self._stds = stds with self.name_scope(): self.corner_to_center = BBoxCornerToCenter(split=True) + # pylint: disable=arguments-differ def hybrid_forward(self, F, samples, matches, anchors, refs): """Not HybridBlock due to use of matches.shape @@ -166,8 +169,11 @@ def __init__(self, num_class, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.)) self._num_class = num_class with self.name_scope(): self.class_agnostic_encoder = NormalizedBoxCenterEncoder(stds=stds, means=means) + if 'box_encode' in nd.contrib.__dict__: + self.means = self.params.get_constant('means', means) + self.stds = self.params.get_constant('stds', stds) - def hybrid_forward(self, F, samples, matches, anchors, labels, refs): + def hybrid_forward(self, F, samples, matches, anchors, labels, refs, means=None, stds=None): """Encode BBox One entry per category Parameters @@ -186,7 +192,11 @@ def hybrid_forward(self, F, samples, matches, anchors, labels, refs): """ # refs [B, M, 4], anchors [B, N, 4], samples [B, N], matches [B, N] # encoded targets [B, N, 4], masks [B, N, 4] - targets, masks = self.class_agnostic_encoder(samples, matches, anchors, refs) + if 'box_encode' in F.contrib.__dict__: + targets, masks = F.contrib.box_encode(samples, matches, anchors, refs, means, stds) + else: + targets, masks = self.class_agnostic_encoder(samples, matches, anchors, refs) + # labels [B, M] -> [B, N, M] ref_labels = F.broadcast_like(labels.reshape((0, 1, -1)), matches, lhs_axes=1, rhs_axes=1) # labels [B, N, M] -> pick from matches [B, N] -> [B, N, 1] @@ -212,42 +222,45 @@ class NormalizedBoxCenterDecoder(gluon.HybridBlock): ---------- stds : array-like of size 4 Std value to be divided from encoded values, default is (0.1, 0.1, 0.2, 0.2). - means : array-like of size 4 - Mean value to be subtracted from encoded values, default is (0., 0., 0., 0.). - clip: float, default is None + clip : float, default is None If given, bounding box target will be clipped to this value. 
+ convert_anchor : boolean, default is False + Whether to convert anchor from corner to center format. """ - def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), means=(0., 0., 0., 0.), - convert_anchor=False, clip=None): + def __init__(self, stds=(0.1, 0.1, 0.2, 0.2), convert_anchor=False, clip=None): super(NormalizedBoxCenterDecoder, self).__init__() assert len(stds) == 4, "Box Encoder requires 4 std values." self._stds = stds - self._means = means self._clip = clip if convert_anchor: self.corner_to_center = BBoxCornerToCenter(split=True) else: self.corner_to_center = None + self._format = 'corner' if convert_anchor else 'center' def hybrid_forward(self, F, x, anchors): + if 'box_decode' in F.contrib.__dict__: + x, anchors = F.amp_multicast(x, anchors, num_outputs=2, cast_narrow=True) + return F.contrib.box_decode(x, anchors, self._stds[0], self._stds[1], self._stds[2], + self._stds[3], clip=self._clip, format=self._format) if self.corner_to_center is not None: a = self.corner_to_center(anchors) else: a = anchors.split(axis=-1, num_outputs=4) p = F.split(x, axis=-1, num_outputs=4) - ox = F.broadcast_add(F.broadcast_mul(p[0] * self._stds[0] + self._means[0], a[2]), a[0]) - oy = F.broadcast_add(F.broadcast_mul(p[1] * self._stds[1] + self._means[1], a[3]), a[1]) - dw = p[2] * self._stds[2] + self._means[2] - dh = p[3] * self._stds[3] + self._means[3] + ox = F.broadcast_add(F.broadcast_mul(p[0] * self._stds[0], a[2]), a[0]) + oy = F.broadcast_add(F.broadcast_mul(p[1] * self._stds[1], a[3]), a[1]) + dw = p[2] * self._stds[2] + dh = p[3] * self._stds[3] if self._clip: dw = F.minimum(dw, self._clip) dh = F.minimum(dh, self._clip) dw = F.exp(dw) dh = F.exp(dh) - ow = F.broadcast_mul(dw, a[2]) / 2 - oh = F.broadcast_mul(dh, a[3]) / 2 + ow = F.broadcast_mul(dw, a[2]) * 0.5 + oh = F.broadcast_mul(dh, a[3]) * 0.5 return F.concat(ox - ow, oy - oh, ox + ow, oy + oh, dim=-1) @@ -366,10 +379,8 @@ def __init__(self, num_class, axis=-1, thresh=0.01): def hybrid_forward(self, F, x): 
scores = x.slice_axis(axis=self._axis, begin=1, end=None) # b x N x fg_class template = F.zeros_like(x.slice_axis(axis=-1, begin=0, end=1)) - cls_ids = [] - for i in range(self._fg_class): - cls_ids.append(template + i) # b x N x 1 - cls_id = F.concat(*cls_ids, dim=-1) # b x N x fg_class + cls_id = F.broadcast_add(template, + F.reshape(F.arange(self._fg_class), shape=(1, 1, self._fg_class))) mask = scores > self._thresh cls_id = F.where(mask, cls_id, F.ones_like(cls_id) * -1) scores = F.where(mask, scores, F.zeros_like(scores)) diff --git a/gluoncv/nn/feature.py b/gluoncv/nn/feature.py index d36c1cdee8..a89d278755 100644 --- a/gluoncv/nn/feature.py +++ b/gluoncv/nn/feature.py @@ -244,8 +244,8 @@ def __init__(self, network, outputs, num_filters, use_1x1=True, use_upsample=Tru attr={'__init__': weight_init}) if norm_layer is not None: if norm_layer is SyncBatchNorm: - norm_kwargs['key'] = "P{}_conv1_bn".format(num_stages - i) - norm_kwargs['name'] = "P{}_conv1_bn".format(num_stages - i) + norm_kwargs['key'] = "P{}_lat_bn".format(num_stages - i) + norm_kwargs['name'] = "P{}_lat_bn".format(num_stages - i) bf = norm_layer(bf, **norm_kwargs) if use_upsample: y = mx.sym.UpSampling(y, scale=2, sample_type='nearest', diff --git a/scripts/detection/faster_rcnn/train_faster_rcnn.py b/scripts/detection/faster_rcnn/train_faster_rcnn.py index 352a5ea8c5..e0c54b93c2 100644 --- a/scripts/detection/faster_rcnn/train_faster_rcnn.py +++ b/scripts/detection/faster_rcnn/train_faster_rcnn.py @@ -264,8 +264,8 @@ def forward_backward(self, x): with autograd.record(): gt_label = label[:, :, 4:5] gt_box = label[:, :, :4] - cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors = net( - data, gt_box) + cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \ + box_targets, box_masks = net(data, gt_box, gt_label) # losses of rpn rpn_score = rpn_score.squeeze(axis=-1) num_rpn_pos = (rpn_cls_targets >= 0).sum() @@ -275,10 +275,6 @@ def 
forward_backward(self, x): rpn_box_masks) * rpn_box.size / num_rpn_pos # rpn overall loss, use sum rather than average rpn_loss = rpn_loss1 + rpn_loss2 - # generate targets for rcnn - cls_targets, box_targets, box_masks = self.net.target_generator(roi, samples, - matches, gt_label, - gt_box) # losses of rcnn num_rcnn_pos = (cls_targets >= 0).sum() rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets, diff --git a/scripts/instance/mask_rcnn/train_mask_rcnn.py b/scripts/instance/mask_rcnn/train_mask_rcnn.py index e1cd6ea664..85aa59c8e8 100644 --- a/scripts/instance/mask_rcnn/train_mask_rcnn.py +++ b/scripts/instance/mask_rcnn/train_mask_rcnn.py @@ -29,6 +29,8 @@ hvd = None +# from mxnet import profiler + def parse_args(): parser = argparse.ArgumentParser(description='Train Mask R-CNN network end to end.') parser.add_argument('--network', type=str, default='resnet50_v1b', @@ -51,13 +53,13 @@ def parse_args(): help='Starting epoch for resuming, default is 0 for new training.' 'You can specify it to 100 for example to start from 100 epoch.') parser.add_argument('--lr', type=str, default='', - help='Learning rate, default is 0.00125 for coco single gpu training.') + help='Learning rate, default is 0.01 for coco 8 gpus training.') parser.add_argument('--lr-decay', type=float, default=0.1, help='decay rate of learning rate. default is 0.1.') parser.add_argument('--lr-decay-epoch', type=str, default='', help='epochs at which learning rate decays. default is 17,23 for coco.') parser.add_argument('--lr-warmup', type=str, default='', - help='warmup iterations to adjust learning rate, default is 8000 for coco.') + help='warmup iterations to adjust learning rate, default is 1000 for coco.') parser.add_argument('--lr-warmup-factor', type=float, default=1. 
/ 3., help='warmup factor of base lr.') parser.add_argument('--momentum', type=float, default=0.9, @@ -168,15 +170,14 @@ def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map)) -pinned_data_stage = {} - -def _stage_data(i, data, ctx_list, stage_data): +def _stage_data(i, data, ctx_list, pinned_data_stage): def _get_chunk(data, storage): s = storage.reshape(shape=(storage.size,)) s = s[:data.size] s = s.reshape(shape=data.shape) data.copyto(s) return s + if ctx_list[0].device_type == "cpu": return data if i not in pinned_data_stage: @@ -189,7 +190,11 @@ def _get_chunk(data, storage): if data[j].size > storage[j].size: storage[j] = data[j].as_in_context(mx.cpu_pinned()) - return [_get_chunk(d, s) for d,s in zip(data, storage)] + return [_get_chunk(d, s) for d, s in zip(data, storage)] + + +pinned_data_stage = {} + def split_and_load(batch, ctx_list): """Split data to 1 batch each device.""" @@ -275,8 +280,8 @@ def forward_backward(self, x): with autograd.record(): gt_label = label[:, :, 4:5] gt_box = label[:, :, :4] - cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors = net( - data, gt_box) + cls_pred, box_pred, mask_pred, roi, samples, matches, rpn_score, rpn_box, anchors, \ + cls_targets, box_targets, box_masks = net(data, gt_box, gt_label) # losses of rpn rpn_score = rpn_score.squeeze(axis=-1) num_rpn_pos = (rpn_cls_targets >= 0).sum() @@ -286,10 +291,7 @@ def forward_backward(self, x): rpn_box_masks) * rpn_box.size / num_rpn_pos # rpn overall loss, use sum rather than average rpn_loss = rpn_loss1 + rpn_loss2 - # generate targets for rcnn - cls_targets, box_targets, box_masks = self.net.target_generator(roi, samples, - matches, gt_label, - gt_box) + # losses of rcnn num_rcnn_pos = (cls_targets >= 0).sum() rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets, @@ -335,10 +337,11 @@ def train(net, train_data, val_data, 
eval_metric, batch_size, ctx, args): """Training pipeline""" kv = mx.kvstore.create('device' if (args.amp and 'nccl' in args.kv_store) else args.kv_store) net.collect_params().setattr('grad_req', 'null') - net.collect_train_params().setattr('grad_req', 'write') + net.collect_train_params().setattr('grad_req', 'add') for k, v in net.collect_params('.*bias').items(): v.wd_mult = 0.0 - optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum} + optimizer_params = {'learning_rate': args.lr, 'wd': args.wd, 'momentum': args.momentum,} + #'clip_gradient': 1.5} if args.horovod: hvd.broadcast_parameters(net.collect_params(), root_rank=0) trainer = hvd.DistributedTrainer( @@ -416,11 +419,9 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): btic = time.time() base_lr = trainer.learning_rate train_data_iter = iter(train_data) - end_of_batch = False next_data_batch = next(train_data_iter) next_data_batch = split_and_load(next_data_batch, ctx_list=ctx) - i = 0 - while not end_of_batch: + for i in range(len(train_data)): batch = next_data_batch if epoch == 0 and i <= lr_warmup: # adjust based on real percentage @@ -432,7 +433,6 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): trainer.set_learning_rate(new_lr) metric_losses = [[] for _ in metrics] add_losses = [[] for _ in metrics2] - # losses = [] if executor is not None: for data in zip(*batch): executor.put(data) @@ -451,22 +451,23 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): next_data_batch = next(train_data_iter) next_data_batch = split_and_load(next_data_batch, ctx_list=ctx) except StopIteration: - end_of_batch = True + pass for metric, record in zip(metrics, metric_losses): metric.update(0, record) for metric, records in zip(metrics2, add_losses): for pred in records: metric.update(pred[0], pred[1]) - trainer.step(batch_size) - # update metrics + if i % 4 == 0: + trainer.step(batch_size * 4) + 
net.collect_train_params().zero_grad() + if (not args.horovod or hvd.rank() == 0) and args.log_interval \ and not (i + 1) % args.log_interval: msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2]) logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format( epoch, i, args.log_interval * args.batch_size / (time.time() - btic), msg)) btic = time.time() - i = i + 1 # validate and save params if (not args.horovod) or hvd.rank() == 0: msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics]) @@ -521,6 +522,15 @@ def train(net, train_data, val_data, eval_metric, batch_size, ctx, args): continue param.initialize() net.collect_params().reset_ctx(ctx) + ''' + sym = net(mx.sym.var(name='data')) + print(sym) + a = mx.viz.plot_network(mx.sym.concat(*sym), + node_attrs={'shape': 'rect', 'fixedsize': 'false'}) + import tempfile + + a.view(tempfile.mktemp('.gv')) + a.render('fpn')''' # training data train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)