This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add mask target generator operator for Mask-RCNN (#16268)
* Add mask target generator operator for Mask-RCNN Signed-off-by: Serge Panev <[email protected]> * Disable the unit test for CPU default ctx Signed-off-by: Serge Panev <[email protected]> * Address PR comments Signed-off-by: Serge Panev <[email protected]>
- Loading branch information
Showing
3 changed files
with
467 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file mrcnn_target-inl.h | ||
* \brief Mask-RCNN target generator | ||
* \author Serge Panev | ||
*/ | ||
|
||
|
||
#ifndef MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ | ||
#define MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ | ||
|
||
#include <mxnet/operator.h> | ||
#include <vector> | ||
#include "../operator_common.h" | ||
#include "../mshadow_op.h" | ||
#include "../tensor/init_op.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
namespace mrcnn_index { | ||
enum ROIAlignOpInputs {kRoi, kGtMask, kMatches, kClasses}; | ||
enum ROIAlignOpOutputs {kMask, kMaskClasses}; | ||
} // namespace mrcnn_index | ||
|
||
struct MRCNNTargetParam : public dmlc::Parameter<MRCNNTargetParam> { | ||
int num_rois; | ||
int num_classes; | ||
int mask_size; | ||
int sample_ratio; | ||
|
||
DMLC_DECLARE_PARAMETER(MRCNNTargetParam) { | ||
DMLC_DECLARE_FIELD(num_rois) | ||
.describe("Number of sampled RoIs."); | ||
DMLC_DECLARE_FIELD(num_classes) | ||
.describe("Number of classes."); | ||
DMLC_DECLARE_FIELD(mask_size) | ||
.describe("Size of the pooled masks."); | ||
DMLC_DECLARE_FIELD(sample_ratio).set_default(2) | ||
.describe("Sampling ratio of ROI align. Set to -1 to use adaptative size."); | ||
} | ||
}; | ||
|
||
inline bool MRCNNTargetShape(const NodeAttrs& attrs, | ||
std::vector<mxnet::TShape>* in_shape, | ||
std::vector<mxnet::TShape>* out_shape) { | ||
using namespace mshadow; | ||
const MRCNNTargetParam& param = nnvm::get<MRCNNTargetParam>(attrs.parsed); | ||
|
||
CHECK_EQ(in_shape->size(), 4) << "Input:[rois, gt_masks, matches, cls_targets]"; | ||
|
||
// (B, N, 4) | ||
mxnet::TShape tshape = in_shape->at(mrcnn_index::kRoi); | ||
CHECK_EQ(tshape.ndim(), 3) << "rois should be a 2D tensor of shape [batch, rois, 4]"; | ||
CHECK_EQ(tshape[2], 4) << "rois should be a 2D tensor of shape [batch, rois, 4]"; | ||
auto batch_size = tshape[0]; | ||
auto num_rois = tshape[1]; | ||
|
||
// (B, M, H, W) | ||
tshape = in_shape->at(mrcnn_index::kGtMask); | ||
CHECK_EQ(tshape.ndim(), 4) << "gt_masks should be a 4D tensor"; | ||
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; | ||
|
||
// (B, N) | ||
tshape = in_shape->at(mrcnn_index::kMatches); | ||
CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor"; | ||
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; | ||
|
||
// (B, N) | ||
tshape = in_shape->at(mrcnn_index::kClasses); | ||
CHECK_EQ(tshape.ndim(), 2) << "matches should be a 2D tensor"; | ||
CHECK_EQ(tshape[0], batch_size) << " batch size should be the same for all the inputs."; | ||
|
||
// out: 2 * (B, N, C, MS, MS) | ||
auto oshape = Shape5(batch_size, num_rois, param.num_classes, param.mask_size, param.mask_size); | ||
out_shape->clear(); | ||
out_shape->push_back(oshape); | ||
out_shape->push_back(oshape); | ||
return true; | ||
} | ||
|
||
inline bool MRCNNTargetType(const NodeAttrs& attrs, | ||
std::vector<int>* in_type, | ||
std::vector<int>* out_type) { | ||
CHECK_EQ(in_type->size(), 4); | ||
int dtype = (*in_type)[1]; | ||
CHECK_NE(dtype, -1) << "Input must have specified type"; | ||
|
||
out_type->clear(); | ||
out_type->push_back(dtype); | ||
out_type->push_back(dtype); | ||
return true; | ||
} | ||
|
||
template<typename xpu> | ||
void MRCNNTargetRun(const MRCNNTargetParam& param, const std::vector<TBlob> &inputs, | ||
const std::vector<TBlob> &outputs, mshadow::Stream<xpu> *s); | ||
|
||
template<typename xpu> | ||
void MRCNNTargetCompute(const nnvm::NodeAttrs& attrs, | ||
const OpContext &ctx, | ||
const std::vector<TBlob> &inputs, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &outputs) { | ||
auto s = ctx.get_stream<xpu>(); | ||
const auto& p = dmlc::get<MRCNNTargetParam>(attrs.parsed); | ||
MRCNNTargetRun<xpu>(p, inputs, outputs, s); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_CONTRIB_MRCNN_TARGET_INL_H_ |
Oops, something went wrong.