diff --git a/src/operator/contrib/bounding_box-inl.h b/src/operator/contrib/bounding_box-inl.h index 686f1666a310..010280e10852 100644 --- a/src/operator/contrib/bounding_box-inl.h +++ b/src/operator/contrib/bounding_box-inl.h @@ -787,6 +787,284 @@ void BipartiteMatchingBackward(const nnvm::NodeAttrs& attrs, }); } + +inline bool BoxEncodeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 6U); + CHECK_EQ(out_attrs->size(), 2U); + mxnet::TShape& sshape = (*in_attrs)[0]; + mxnet::TShape& mshape = (*in_attrs)[1]; + mxnet::TShape& ashape = (*in_attrs)[2]; + mxnet::TShape& rshape = (*in_attrs)[3]; + + CHECK_EQ(sshape.ndim(), 2) + << "samples shape must have dim == 2, " + << sshape.ndim() << " provided"; + + CHECK_GE(mshape.ndim(), 2) + << "matches shape must have dim == 2, " + << mshape.ndim() << " provided"; + + CHECK_GE(ashape.ndim(), 3) + << "matches shape must have dim == 3, " + << ashape.ndim() << " provided"; + int ldim = ashape[ashape.ndim() - 1]; + CHECK_EQ(ldim, 4) + << "last dimension of anchors must be 4, " + << ldim << " provided"; + + CHECK_GE(rshape.ndim(), 3) + << "refs shape must have dim == 3, " + << ashape.ndim() << " provided"; + ldim = rshape[rshape.ndim() - 1]; + CHECK_EQ(ldim, 4) + << "last dimension of anchors must be 4, " + << ldim << " provided"; + + // asign input shape + SHAPE_ASSIGN_CHECK(*in_attrs, 4, mshadow::Shape1(4)); + SHAPE_ASSIGN_CHECK(*in_attrs, 5, mshadow::Shape1(4)); + + // assign output shape + mxnet::TShape oshape = ashape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, oshape); + return shape_is_known(oshape); +} + +struct box_encode { + template + MSHADOW_XINLINE static void Map(index_t i, DType *out_targets, DType *out_masks, + const DType *samples, const DType *matches, + const DType *anchors, const DType *refs, + const DType *means, const DType *stds, + const int m, const int n) { + index_t j = i / n; + index_t match = matches[i]; + // xmin: 0, ymin:1, xmax: 2, ymax: 3 + // x:0, y:1, w:2, h:3 + index_t ref_index = (j * m + match) * 4; + DType ref_xmin = refs[ref_index + 0]; + DType ref_ymin = refs[ref_index + 1]; + DType ref_width = refs[ref_index + 2] - ref_xmin; + DType ref_height = refs[ref_index + 3] - ref_ymin; + DType ref_x = ref_xmin + ref_width * 0.5; + DType ref_y = ref_ymin + ref_height * 0.5; + index_t a_index = i * 4; + DType a_xmin = anchors[a_index + 0]; + DType a_ymin = anchors[a_index + 1]; + DType a_width = anchors[a_index + 2] - a_xmin; + DType a_height = anchors[a_index + 3] - a_ymin; + DType a_x = a_xmin + a_width * 0.5; + DType a_y = a_ymin + a_height * 0.5; + DType valid = samples[i] > 0.5 ? 1.0 : 0.0; + out_masks[a_index + 0] = valid; + out_masks[a_index + 1] = valid; + out_masks[a_index + 2] = valid; + out_masks[a_index + 3] = valid; + out_targets[a_index + 0] = valid > static_cast(0.5) ? + ((ref_x - a_x) / a_width - static_cast(means[0])) / + static_cast(stds[0]) : static_cast(0.0); + out_targets[a_index + 1] = valid > static_cast(0.5) ? + ((ref_y - a_y) / a_height - static_cast(means[1])) / + static_cast(stds[1]) : static_cast(0.0); + out_targets[a_index + 2] = valid > static_cast(0.5) ? + (log(ref_width / a_width) - static_cast(means[2])) / + static_cast(stds[2]) : static_cast(0.0); + out_targets[a_index + 3] = valid > static_cast(0.5) ? + (log(ref_height / a_height) - static_cast(means[3])) / + static_cast(stds[3]) : static_cast(0.0); + } +}; + +template +void BoxEncodeForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 6U); + CHECK_EQ(outputs.size(), 2U); + Stream *s = ctx.get_stream(); + // samples, matches, anchors, refs, means, stds + mxnet::TShape anchor_shape = inputs[2].shape_; + int loop_size = anchor_shape.ProdShape(0, 2); + int b = anchor_shape[0]; + int n = anchor_shape[1]; + int m = inputs[3].shape_[1]; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor samples = inputs[0] + .get_with_shape(Shape2(b, n), s); + Tensor matches = inputs[1] + .get_with_shape(Shape2(b, n), s); + Tensor anchors = inputs[2] + .get_with_shape(Shape3(b, n, 4), s); + Tensor refs = inputs[3] + .get_with_shape(Shape3(b, m, 4), s); + Tensor means = inputs[4] + .get_with_shape(Shape1(4), s); + Tensor stds = inputs[5] + .get_with_shape(Shape1(4), s); + Tensor out_targets = outputs[0] + .get_with_shape(Shape3(b, n, 4), s); + Tensor out_masks = outputs[1] + .get_with_shape(Shape3(b, n, 4), s); + + Kernel::Launch(s, loop_size, out_targets.dptr_, + out_masks.dptr_, samples.dptr_, matches.dptr_, anchors.dptr_, + refs.dptr_, means.dptr_, stds.dptr_, m, n); + }); +} + +struct BoxDecodeParam : public dmlc::Parameter { + float std0; + float std1; + float std2; + float std3; + float clip; + int format; + DMLC_DECLARE_PARAMETER(BoxDecodeParam) { + DMLC_DECLARE_FIELD(std0).set_default(1.0) + .describe("value to be divided from the 1st encoded values"); + DMLC_DECLARE_FIELD(std1).set_default(1.0) + .describe("value to be divided from the 2nd encoded values"); + DMLC_DECLARE_FIELD(std2).set_default(1.0) + .describe("value to be divided from the 3rd encoded values"); + DMLC_DECLARE_FIELD(std3).set_default(1.0) + .describe("value to be divided from the 4th encoded values"); + DMLC_DECLARE_FIELD(clip).set_default(-1.0) + .describe("If larger than 0, bounding box target will be clipped to this value."); + DMLC_DECLARE_FIELD(format).set_default(box_common_enum::kCenter) + .add_enum("corner", box_common_enum::kCorner) + .add_enum("center", box_common_enum::kCenter) + .describe("The box encoding type. \n" + " \"corner\" means boxes are encoded as [xmin, ymin, xmax, ymax]," + " \"center\" means boxes are encodes as [x, y, width, height]."); + } +}; // BoxDecodeParam + +inline bool BoxDecodeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + mxnet::TShape& dshape = (*in_attrs)[0]; + mxnet::TShape& ashape = (*in_attrs)[1]; + + CHECK_EQ(dshape.ndim(), 3) + << "data shape must have dim == 3, " + << dshape.ndim() << " provided"; + int ldim = dshape[dshape.ndim() - 1]; + CHECK_EQ(ldim, 4) + << "last dimension of data must be 4, " + << ldim << " provided"; + + CHECK_GE(ashape.ndim(), 3) + << "anchors shape must have dim == 3, " + << ashape.ndim() << " provided"; + ldim = ashape[ashape.ndim() - 1]; + CHECK_EQ(ldim, 4) + << "last dimension of anchors must be 4, " + << ldim << " provided"; + + // assign output shape + mxnet::TShape oshape = dshape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + return shape_is_known(oshape); +} + +template +struct box_decode { + template + MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *x, + const DType *anchors, const DType std0, + const DType std1, const DType std2, + const DType std3, const DType clip, + const int n) { + index_t index = i * 4; + index_t a_index = (i % n) * 4; + DType a_x = anchors[a_index + 0]; + DType a_y = anchors[a_index + 1]; + DType a_width = anchors[a_index + 2]; + DType a_height = anchors[a_index + 3]; + if (box_common_enum::kCorner == anchor_encode) { + // a_x = xmin, a_y = ymin, a_width = xmax, a_height = ymax + a_width = a_width - a_x; + a_height = a_height - a_y; + a_x = a_x + a_width * 0.5; + a_y = a_y + a_height * 0.5; + } + DType ox = x[index + 0] * std0 * a_width + a_x; + DType oy = x[index + 1] * std1 * a_height + a_y; + DType dw = x[index + 2] * std2; + DType dh = x[index + 3] * std3; + if (has_clip) { + dw = dw < clip ? dw : clip; + dh = dh < clip ? dh : clip; + } + dw = exp(dw); + dh = exp(dh); + DType ow = dw * a_width * 0.5; + DType oh = dh * a_height * 0.5; + out[index + 0] = ox - ow; + out[index + 1] = oy - oh; + out[index + 2] = ox + ow; + out[index + 3] = oy + oh; + } +}; + +template +void BoxDecodeForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + Stream *s = ctx.get_stream(); + mxnet::TShape x_shape = inputs[0].shape_; + int b = x_shape[0]; + int n = x_shape[1]; + int loop_size = b * n; + const BoxDecodeParam& param = nnvm::get(attrs.parsed); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor data = inputs[0] + .get_with_shape(Shape3(b, n, 4), s); + Tensor anchors = inputs[1] + .get_with_shape(Shape3(1, n, 4), s); + Tensor out = outputs[0] + .get_with_shape(Shape3(b, n, 4), s); + if (box_common_enum::kCorner == param.format && param.clip > 0.0) { + Kernel, xpu>::Launch(s, loop_size, + out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), + static_cast(param.std1), static_cast(param.std2), + static_cast(param.std3), static_cast(param.clip), n); + } else if (box_common_enum::kCenter == param.format && param.clip > 0.0) { + Kernel, xpu>::Launch(s, loop_size, + out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), + static_cast(param.std1), static_cast(param.std2), + static_cast(param.std3), static_cast(param.clip), n); + } else if (box_common_enum::kCorner == param.format && param.clip <= 0.0) { + Kernel, xpu>::Launch(s, loop_size, + out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), + static_cast(param.std1), static_cast(param.std2), + static_cast(param.std3), static_cast(param.clip), n); + } else { + Kernel, xpu>::Launch(s, loop_size, + out.dptr_, data.dptr_, anchors.dptr_, static_cast(param.std0), + static_cast(param.std1), static_cast(param.std2), + static_cast(param.std3), static_cast(param.clip), n); + } + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box.cc b/src/operator/contrib/bounding_box.cc index d682fafec092..62b7c2e0bf4b 100644 --- a/src/operator/contrib/bounding_box.cc +++ b/src/operator/contrib/bounding_box.cc @@ -32,6 +32,7 @@ namespace op { DMLC_REGISTER_PARAMETER(BoxNMSParam); DMLC_REGISTER_PARAMETER(BoxOverlapParam); DMLC_REGISTER_PARAMETER(BipartiteMatchingParam); +DMLC_REGISTER_PARAMETER(BoxDecodeParam); NNVM_REGISTER_OP(_contrib_box_nms) .add_alias("_contrib_box_non_maximum_suppression") @@ -201,5 +202,47 @@ NNVM_REGISTER_OP(_backward_contrib_bipartite_matching) .set_attr("FCompute", BipartiteMatchingBackward) .add_arguments(BipartiteMatchingParam::__FIELDS__()); +NNVM_REGISTER_OP(_contrib_box_encode) +.describe(R"doc(Encode bounding boxes training target with normalized center offsets. + Input bounding boxes are using corner type: `x_{min}, y_{min}, x_{max}, y_{max}`.) array +)doc" ADD_FILELINE) +.set_num_inputs(6) +.set_num_outputs(2) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"samples", "matches", "anchors", "refs", "means", "stds"}; + }) +.set_attr("FInferShape", BoxEncodeShape) +.set_attr("FInferType", ElemwiseType<6, 2>) +.set_attr("FCompute", BoxEncodeForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("samples", "NDArray-or-Symbol", "(B, N) value +1 (positive), -1 (negative), " + "0 (ignore)") +.add_argument("matches", "NDArray-or-Symbol", "(B, N) value range [0, M)") +.add_argument("anchors", "NDArray-or-Symbol", "(B, N, 4) encoded in corner") +.add_argument("refs", "NDArray-or-Symbol", "(B, M, 4) encoded in corner") +.add_argument("means", "NDArray-or-Symbol", "(4,) Mean value to be subtracted from encoded values") +.add_argument("stds", "NDArray-or-Symbol", "(4,) Std value to be divided from encoded values"); + +NNVM_REGISTER_OP(_contrib_box_decode) +.describe(R"doc(Decode bounding boxes training target with normalized center offsets. + Input bounding boxes are using corner type: `x_{min}, y_{min}, x_{max}, y_{max}` + or center type: `x, y, width, height.) array +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "anchors"}; + }) +.set_attr("FInferShape", BoxDecodeShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", BoxDecodeForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "(B, N, 4) predicted bbox offset") +.add_argument("anchors", "NDArray-or-Symbol", "(1, N, 4) encoded in corner or center") +.add_arguments(BoxDecodeParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/bounding_box.cu b/src/operator/contrib/bounding_box.cu index 2677d2f79478..b20c570ea417 100644 --- a/src/operator/contrib/bounding_box.cu +++ b/src/operator/contrib/bounding_box.cu @@ -47,5 +47,12 @@ NNVM_REGISTER_OP(_contrib_bipartite_matching) NNVM_REGISTER_OP(_backward_contrib_bipartite_matching) .set_attr("FCompute", BipartiteMatchingBackward); + +NNVM_REGISTER_OP(_contrib_box_encode) +.set_attr("FCompute", BoxEncodeForward); + +NNVM_REGISTER_OP(_contrib_box_decode) +.set_attr("FCompute", BoxDecodeForward); + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/roi_align.cc b/src/operator/contrib/roi_align.cc index 53ddba02bc7b..ee91561a6818 100644 --- a/src/operator/contrib/roi_align.cc +++ b/src/operator/contrib/roi_align.cc @@ -167,6 +167,10 @@ num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount()) int roi_batch_ind = 0; if (roi_cols == 5) { roi_batch_ind = offset_bottom_rois[0]; + if (roi_batch_ind < 0) { + top_data[n] = 0; + continue; + } offset_bottom_rois++; } @@ -340,6 +344,7 @@ void ROIAlignBackward( int roi_batch_ind = 0; if (rois_cols == 5) { roi_batch_ind = offset_bottom_rois[0]; + if (roi_batch_ind < 0) continue; offset_bottom_rois++; } @@ -520,7 +525,8 @@ NNVM_REGISTER_OP(_contrib_ROIAlign) .describe(R"code( This operator takes a 4D feature map as an input array and region proposals as `rois`, then align the feature map over sub-regions of input and produces a fixed-sized output array. -This operator is typically used in Faster R-CNN & Mask R-CNN networks. +This operator is typically used in Faster R-CNN & Mask R-CNN networks. If roi batchid is less +than 0, it will be ignored, and the corresponding output will be set to 0. Different from ROI pooling, ROI Align removes the harsh quantization, properly aligning the extracted features with the input. RoIAlign computes the value of each sampling point @@ -594,7 +600,8 @@ He, Kaiming, et al. "Mask R-CNN." ICCV, 2017 return MakeGradNode("_backward_ROIAlign", n, heads, n->attrs.dict); }) .add_argument("data", "NDArray-or-Symbol", "Input data to the pooling operator, a 4D Feature maps") -.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array") +.add_argument("rois", "NDArray-or-Symbol", "Bounding box coordinates, a 2D array, " + "if batchid is less than 0, it will be ignored.") .add_arguments(ROIAlignParam::__FIELDS__()); diff --git a/src/operator/contrib/roi_align.cu b/src/operator/contrib/roi_align.cu index 38b461d5f58c..1d90fbff5193 100644 --- a/src/operator/contrib/roi_align.cu +++ b/src/operator/contrib/roi_align.cu @@ -130,6 +130,11 @@ __global__ void RoIAlignForwardKernel( const T* offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; + if (roi_batch_ind < 0) { + top_data[index] = 0.; + continue; + } + // Do not using rounding; this implementation detail is critical T roi_start_w = offset_bottom_rois[1] * spatial_scale; T roi_start_h = offset_bottom_rois[2] * spatial_scale; @@ -268,6 +273,7 @@ __global__ void RoIAlignBackwardKernel( const T* offset_bottom_rois = bottom_rois + n * 5; int roi_batch_ind = offset_bottom_rois[0]; + if (roi_batch_ind < 0) continue; // Do not using rounding; this implementation detail is critical T roi_start_w = offset_bottom_rois[1] * spatial_scale; diff --git a/src/operator/tensor/amp_cast.h b/src/operator/tensor/amp_cast.h index a722b417c715..be7d400ca153 100644 --- a/src/operator/tensor/amp_cast.h +++ b/src/operator/tensor/amp_cast.h @@ -48,10 +48,13 @@ struct AMPCastParam : public dmlc::Parameter { struct AMPMultiCastParam : public dmlc::Parameter { int num_outputs; + bool cast_narrow; DMLC_DECLARE_PARAMETER(AMPMultiCastParam) { DMLC_DECLARE_FIELD(num_outputs) .describe("Number of input/output pairs to be casted to the widest type."); + DMLC_DECLARE_FIELD(cast_narrow).set_default(false) + .describe("Whether to cast to the narrowest type"); } }; @@ -80,10 +83,12 @@ inline bool AMPMultiCastType(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), param.num_outputs); CHECK_EQ(out_attrs->size(), param.num_outputs); bool ret = true; - int widest_type = kFloat16; + int widest_type = param.cast_narrow ? kFloat32 : kFloat16; for (int i = 0; i < param.num_outputs; ++i) { - if ((*in_attrs)[i] == kFloat32 || (*out_attrs)[i] == kFloat32) { + if (!param.cast_narrow && ((*in_attrs)[i] == kFloat32 || (*out_attrs)[i] == kFloat32)) { widest_type = kFloat32; + } else if (param.cast_narrow &&((*in_attrs)[i] == kFloat16 || (*out_attrs)[i] == kFloat16)) { + widest_type = kFloat16; } } for (int i = 0; i < param.num_outputs; ++i) { diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 813a6b092b9c..43ba9d8b6318 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -333,6 +333,24 @@ def test_multibox_prior_op(): boxes = Y.reshape((h, w, 5, 4)) assert_allclose(boxes.asnumpy()[250, 250, 0, :], np.array([-0.948249, 0.362671, 1.636436, 0.530377]), atol=1e-5, rtol=1e-5) +def test_box_encode_op(): + anchors = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4)) + refs = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4)) + samples = mx.nd.array([[0, 1]]) + matches = mx.nd.array([[0, 1]]) + means = mx.nd.array([0.0, 0.0, 0.0, 0.0]) + stds = mx.nd.array([0.1, 0.1, 0.2, 0.2]) + Y, mask = mx.nd.contrib.box_encode(samples, matches, anchors, refs, means, stds) + assert_allclose(Y.asnumpy(), np.zeros((1, 2, 4)), atol=1e-5, rtol=1e-5) + assert_allclose(mask.asnumpy(), np.array([[[0., 0., 0., 0.], [1., 1., 1., 1.]]]), atol=1e-5, rtol=1e-5) + +def test_box_decode_op(): + data = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4)) + anchors = mx.nd.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]).reshape((1, -1, 4)) + Y = mx.nd.contrib.box_decode(data, anchors, .1, .1, .2, .2) + assert_allclose(Y.asnumpy(), np.array([[[-0.0562755, -0.00865743, 0.26227552, 0.42465743], \ + [0.13240421, 0.17859563, 0.93759584, 1.1174043 ]]]), atol=1e-5, rtol=1e-5) + if __name__ == '__main__': import nose nose.runmodule()