diff --git a/src/operator/contrib/mrcnn_target-inl.h b/src/operator/contrib/mrcnn_target-inl.h new file mode 100644 index 000000000000..a3d8c444086a --- /dev/null +++ b/src/operator/contrib/mrcnn_target-inl.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mrcnn_target-inl.h + * \brief Mask-RCNN target generator + * \author Serge Panev + */ + + +#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ +#define MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ + +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + +namespace mrcnn_index { + enum ROIAlignOpInputs {kRoi, kGtMask, kMatches, kClasses}; + enum ROIAlignOpOutputs {kMask, kMaskClasses}; +} // namespace mrcnn_index + +struct MRCNNTargetParam : public dmlc::Parameter { + int num_rois; + int num_classes; + int mask_size; + int sample_ratio; + + DMLC_DECLARE_PARAMETER(MRCNNTargetParam) { + DMLC_DECLARE_FIELD(num_rois) + .describe("Number of sampled RoIs."); + DMLC_DECLARE_FIELD(num_classes) + .describe("Number of classes."); + DMLC_DECLARE_FIELD(mask_size) + .describe("Size of the pooled masks."); + DMLC_DECLARE_FIELD(sample_ratio).set_default(2) + .describe("Sampling ratio of ROI align. Set to -1 to use adaptative size."); + } +}; + +inline bool MRCNNTargetShape(const NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + using namespace mshadow; + const MRCNNTargetParam& param = nnvm::get(attrs.parsed); + + CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]"; + + // (B, N, 4) + mxnet::TShape tshape = in_shape->at(mrcnn_index::kRoi); + CHECK_EQ(tshape.ndim(), 3) << "rois should be a 2D tensor of shape [batch, rois, 4]"; + CHECK_EQ(tshape[2], 4) << "rois should be a 2D tensor of shape [batch, rois, 4]"; + auto batch_size = tshape[0]; + auto num_rois = tshape[1]; + + // (B, M, H, W) + tshape = in_shape->at(mrcnn_index::kGtMask); + CHECK_EQ(tshape.ndim(), 4) << "gt_masks should be a 4D tensor"; + CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; + + // (B, N) + tshape = in_shape->at(mrcnn_index::kMatches); + CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor"; + CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; + + // (B, N) + tshape = in_shape->at(mrcnn_index::kClasses); + CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor"; + CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; + + // out: 2 * (B, N, C, MS, MS) + auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size); + out_shape->clear(); + out_shape->push_back(oshape); + out_shape->push_back(oshape); + return true; +} + +inline bool MRCNNTargetType(const NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + CHECK_EQ(in_type->size(), 4); + int dtype = (*in_type)[1]; + CHECK_NE(dtype, -1) << "Input must have specified type"; + + out_type->clear(); + out_type->push_back(dtype); + out_type->push_back(dtype); + return true; +} + +template +void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector &inputs, + const std::vector &outputs, mshadow::Stream *s); + +template +void MRCNNTargetCompute(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + auto s = ctx.get_stream(); + const auto& p = dmlc::get(attrs.parsed); + MRCNNTargetRun(p, inputs, outputs, s); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ diff --git a/src/operator/contrib/mrcnn_target.cu b/src/operator/contrib/mrcnn_target.cu new file mode 100644 index 000000000000..d542c9c220f8 --- /dev/null +++ b/src/operator/contrib/mrcnn_target.cu @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file mrcnn_target.cu + * \brief Mask-RCNN target generator + * \author Serge Panev + */ + +#include "./mrcnn_target-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow::cuda; + +// The maximum number of blocks to use in the default kernel call. +constexpr int MAXIMUM_NUM_BLOCKS = 4096; + +inline int CUDA_GET_BLOCKS(const int N) { + return std::max( + std::min( + (N + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock, + MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block. + 1); +} + +// Kernels + +template +__device__ T bilinear_interpolate( + const T* in_data, + const int height, + const int width, + T y, + T x, + const int index /* index for debug only*/) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y <= 0) { + y = 0; + } + if (x <= 0) { + x = 0; + } + + int y_low = static_cast(y); + int x_low = static_cast(x); + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = in_data[y_low * width + x_low]; + T v2 = in_data[y_low * width + x_high]; + T v3 = in_data[y_high * width + x_low]; + T v4 = in_data[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +// Modified version of RoIAlignForwardKernel from Caffe (in roi_align.cu) +// Main modifications: +// - We don't need position_sensitive neither spatial_scale from the original RoIAlign kernel. +// - We replace `channels` by `num_classes` and modify the logic consequently (e.g. offset_in_data +// does not use `c` anymore). +template +__device__ void RoIAlignForward( + const T* in_data, // (B, M, H, W) + const T* rois, // (B, N, 4) + const T* matches, // (B, N) + const int num_el, + const int num_classes, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const int num_rois, + const int num_gtmasks, + T* out_data) { // (B, N, C, H, W) + // Update kernel + for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; + index < num_el; + index += blockDim.x * gridDim.x) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + // int c = (index / pooled_width / pooled_height) % num_classes; + int n = (index / pooled_width / pooled_height / num_classes) % num_rois; + int batch_idx = (index / pooled_width / pooled_height / num_classes / num_rois); + + int roi_batch_ind = matches[batch_idx * num_rois + n]; + + const T* offset_rois = rois + batch_idx * (4 * num_rois) + n * 4; + // Do not using rounding; this implementation detail is critical + T roi_start_w = offset_rois[0]; + T roi_start_h = offset_rois[1]; + T roi_end_w = offset_rois[2]; + T roi_end_h = offset_rois[3]; + + // Force malformed ROIs to be 1x1 + T roi_width = max(roi_end_w - roi_start_w, (T)1.); + T roi_height = max(roi_end_h - roi_start_h, (T)1.); + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_in_data = + in_data + batch_idx * num_gtmasks * height * width + + roi_batch_ind * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 + const T y = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T x = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + T val = bilinear_interpolate( + offset_in_data, height, width, y, x, index); + output_val += val; + } + } + output_val /= count; + + out_data[index] = output_val; + } +} + + +template +__global__ void MRCNNTargetKernel(const DType *rois, + const DType *gt_masks, + const DType *matches, + const DType *cls_targets, + DType* sampled_masks, + DType* mask_cls, + const int total_out_el, + int batch_size, + int num_classes, + int num_rois, + int num_gtmasks, + int gt_height, + int gt_width, + int mask_size, + int sample_ratio) { + // computing sampled_masks + RoIAlignForward(gt_masks, rois, matches, total_out_el, + num_classes, gt_height, gt_width, mask_size, mask_size, + sample_ratio, num_rois, num_gtmasks, sampled_masks); + // computing mask_cls + int num_masks = batch_size * num_rois * num_classes; + int mask_vol = mask_size * mask_size; + for (int mask_idx = blockIdx.x; mask_idx < num_masks; mask_idx += gridDim.x) { + int cls_idx = mask_idx % num_classes; + int roi_idx = (mask_idx / num_classes) % num_rois; + int batch_idx = (mask_idx / num_classes / num_rois); + + DType* mask_cls_out = mask_cls + mask_idx * mask_vol; + + DType cls_target = cls_targets[batch_idx * num_rois + roi_idx]; + DType out_val = (cls_target == cls_idx); + for (int mask_pixel = threadIdx.x; mask_pixel < mask_vol; mask_pixel += blockDim.x) { + mask_cls_out[mask_pixel] = out_val; + } + } +} + +template<> +void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector &inputs, + const std::vector &outputs, mshadow::Stream *s) { + const int block_dim_size = kMaxThreadsPerBlock; + using namespace mxnet_op; + using mshadow::Tensor; + + auto stream = mshadow::Stream::GetStream(s); + + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + auto rois = inputs[mrcnn_index::kRoi].FlatToKD(s); + auto gt_masks = inputs[mrcnn_index::kGtMask].FlatToKD(s); + auto matches = inputs[mrcnn_index::kMatches].FlatTo2D(s); + auto cls_targets = inputs[mrcnn_index::kClasses].FlatTo2D(s); + + auto out_masks = outputs[mrcnn_index::kMask].FlatToKD(s); + auto out_mask_cls = outputs[mrcnn_index::kMaskClasses].FlatToKD(s); + + int batch_size = gt_masks.shape_[0]; + int num_gtmasks = gt_masks.shape_[1]; + int gt_height = gt_masks.shape_[2]; + int gt_width = gt_masks.shape_[3]; + + int num_el = outputs[mrcnn_index::kMask].Size(); + + dim3 dimGrid = dim3(CUDA_GET_BLOCKS(num_el)); + dim3 dimBlock = dim3(block_dim_size); + + MRCNNTargetKernel<<>> + (rois.dptr_, gt_masks.dptr_, matches.dptr_, cls_targets.dptr_, + out_masks.dptr_, out_mask_cls.dptr_, + num_el, batch_size, param.num_classes, param.num_rois, + num_gtmasks, gt_height, gt_width, param.mask_size, param.sample_ratio); + MSHADOW_CUDA_POST_KERNEL_CHECK(MRCNNTargetKernel); + }); +} + +DMLC_REGISTER_PARAMETER(MRCNNTargetParam); + +NNVM_REGISTER_OP(mrcnn_target) +.describe("Generate mask targets for Mask-RCNN.") +.set_num_inputs(4) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MRCNNTargetShape) +.set_attr("FInferType", MRCNNTargetType) +.set_attr("FCompute", MRCNNTargetCompute) +.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 3D array") +.add_argument("gt_masks", "NDArray-or-Symbol", "Input masks of full image size, a 4D array") +.add_argument("matches", "NDArray-or-Symbol", "Index to a gt_mask, a 2D array") +.add_argument("cls_targets", "NDArray-or-Symbol", + "Value [0, num_class), excluding background class, a 2D array") +.add_arguments(MRCNNTargetParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 962e8687885b..0556f3aaf299 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8639,6 +8639,63 @@ def test_rroi_align_value(sampling_ratio=-1): test_rroi_align_value() test_rroi_align_value(sampling_ratio=2) +@with_seed() +def test_op_mrcnn_target(): + if default_context().device_type != 'gpu': + return + + num_rois = 2 + num_classes = 4 + mask_size = 3 + ctx = mx.gpu(0) + # (B, N, 4) + rois = mx.nd.array([[[2.3, 4.3, 2.2, 3.3], + [3.5, 5.5, 0.9, 2.4]]], ctx=ctx) + gt_masks = mx.nd.arange(0, 4*32*32, ctx=ctx).reshape(1, 4, 32, 32) + + # (B, N) + matches = mx.nd.array([[2, 0]], ctx=ctx) + # (B, N) + cls_targets = mx.nd.array([[2, 1]], ctx=ctx) + + mask_targets, mask_cls = mx.nd.mrcnn_target(rois, gt_masks, matches, cls_targets, + num_rois=num_rois, + num_classes=num_classes, + mask_size=mask_size) + + # Ground truth outputs were generated with GluonCV's target generator + # gluoncv.model_zoo.mask_rcnn.MaskTargetGenerator(1, num_rois, num_classes, mask_size) + gt_mask_targets = mx.nd.array([[[[[2193.4 , 2193.7332 , 2194.0667 ], + [2204.0667 , 2204.4 , 2204.7334 ], + [2214.7334 , 2215.0667 , 2215.4 ]], + [[2193.4 , 2193.7332 , 2194.0667 ], + [2204.0667 , 2204.4 , 2204.7334 ], + [2214.7334 , 2215.0667 , 2215.4 ]], + [[2193.4 , 2193.7332 , 2194.0667 ], + [2204.0667 , 2204.4 , 2204.7334 ], + [2214.7334 , 2215.0667 , 2215.4 ]], + [[2193.4 , 2193.7332 , 2194.0667 ], + [2204.0667 , 2204.4 , 2204.7334 ], + [2214.7334 , 2215.0667 , 2215.4 ]]], + [[[ 185. , 185.33334, 185.66667], + [ 195.66667, 196.00002, 196.33334], + [ 206.33333, 206.66666, 207. ]], + [[ 185. , 185.33334, 185.66667], + [ 195.66667, 196.00002, 196.33334], + [ 206.33333, 206.66666, 207. ]], + [[ 185. , 185.33334, 185.66667], + [ 195.66667, 196.00002, 196.33334], + [ 206.33333, 206.66666, 207. ]], + [[ 185. , 185.33334, 185.66667], + [ 195.66667, 196.00002, 196.33334], + [ 206.33333, 206.66666, 207. ]]]]]) + + gt_mask_cls = mx.nd.array([[0,0,1,0], [0,1,0,0]]) + gt_mask_cls = gt_mask_cls.reshape(1,2,4,1,1).broadcast_axes(axis=(3,4), size=(3,3)) + + assert_almost_equal(mask_targets.asnumpy(), gt_mask_targets.asnumpy()) + assert_almost_equal(mask_cls.asnumpy(), gt_mask_cls.asnumpy()) + @with_seed() def test_diag():