Skip to content

Commit

Permalink
Rename the op and move to contrib
Browse files Browse the repository at this point in the history
Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L committed Oct 15, 2019
1 parent 577817c commit d7bdd5c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions gluoncv/model_zoo/mask_rcnn/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def reset_class(self, classes, reuse_weights=None):
# set data to new conv layers
new_params.set_data(new_data)

class MRCNNTargetWrapper(nn.HybridBlock):
class MRCNNMaskTargetWrapper(nn.HybridBlock):
"""Wrapper for the fused Mask RCNN mask targets generator.
Parameters
Expand All @@ -169,16 +169,16 @@ class MRCNNTargetWrapper(nn.HybridBlock):
"""

def __init__(self, num_rois, num_classes, mask_size, **kwargs):
super(MRCNNTargetWrapper, self).__init__(**kwargs)
super(MRCNNMaskTargetWrapper, self).__init__(**kwargs)
self._num_rois = num_rois
self._num_classes = num_classes
self._mask_size = mask_size[0]

def hybrid_forward(self, F, roi, gt_mask, matches, cls_targets):
return F.mrcnn_target(roi, gt_mask, matches, cls_targets,
num_rois=self._num_rois,
num_classes=self._num_classes,
mask_size=self._mask_size)
return F.contrib.mrcnn_mask_target(roi, gt_mask, matches, cls_targets,
num_rois=self._num_rois,
num_classes=self._num_classes,
mask_size=self._mask_size)

class MaskRCNN(FasterRCNN):
r"""Mask RCNN network.
Expand Down Expand Up @@ -222,8 +222,8 @@ def __init__(self, features, top_features, classes, mask_channels=256, rcnn_max_
norm_layer=norm_layer, norm_kwargs=norm_kwargs)
roi_size = (self._roi_size[0] * target_roi_scale, self._roi_size[1] * target_roi_scale)
self._target_roi_size = roi_size
if hasattr(mx.nd, 'mrcnn_target'):
self.mask_target = MRCNNTargetWrapper(
if hasattr(mx.nd, 'mrcnn_mask_target'):
self.mask_target = MRCNNMaskTargetWrapper(
self._num_sample, self.num_class, self._target_roi_size)
else: # for backward compatibility
self.mask_target = MaskTargetGenerator(
Expand Down

0 comments on commit d7bdd5c

Please sign in to comment.