Skip to content

Commit

Permalink
Add cpu implementation for Deformable PSROIPooling (apache#14886)
Browse files Browse the repository at this point in the history
* add cpu deformable_psroi_pooling forward

* add cpu deformable_psroi_pooling backward

* add consistency checks

* fix nullptr

* fix code style

* fix lint

* fix code style

* update to index_t

* fix lint

* fix compile
  • Loading branch information
arcadiaphy authored and haohuw committed Jun 23, 2019
1 parent 628b83d commit ca6920e
Show file tree
Hide file tree
Showing 4 changed files with 534 additions and 268 deletions.
50 changes: 25 additions & 25 deletions src/operator/contrib/deformable_psroi_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ namespace deformablepsroipool {
struct DeformablePSROIPoolingParam : public dmlc::Parameter<DeformablePSROIPoolingParam> {
// mxnet::TShape pooled_size;
float spatial_scale;
int output_dim;
int group_size;
int pooled_size;
int part_size;
int sample_per_part;
index_t output_dim;
index_t group_size;
index_t pooled_size;
index_t part_size;
index_t sample_per_part;
float trans_std;
bool no_trans;
DMLC_DECLARE_PARAMETER(DeformablePSROIPoolingParam) {
Expand All @@ -82,10 +82,10 @@ class DeformablePSROIPoolingOp : public Operator {
}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
size_t in_expected = param_.no_trans? 2 : 3;
size_t out_expected = 2;
Expand Down Expand Up @@ -119,12 +119,12 @@ class DeformablePSROIPoolingOp : public Operator {
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
size_t in_expected = param_.no_trans ? 2 : 3;
size_t out_expected = 2;
Expand Down Expand Up @@ -216,8 +216,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
}

bool InferShape(mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape,
mxnet::ShapeVector *aux_shape) const override {
mxnet::ShapeVector *out_shape,
mxnet::ShapeVector *aux_shape) const override {
using namespace mshadow;
if (param_.no_trans) {
CHECK_EQ(in_shape->size(), 2) << "Input:[data, rois]";
Expand Down Expand Up @@ -248,8 +248,8 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
CHECK_GE(in_type->size(), 2);
int dtype = (*in_type)[0];
CHECK_EQ(dtype, (*in_type)[1]);
Expand All @@ -272,10 +272,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
}

// decalre dependency and inplace optimization options
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
std::vector<int> DeclareBackwardDependency(const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
if (param_.no_trans) {
return{ out_grad[deformablepsroipool::kOut], in_data[deformablepsroipool::kData],
in_data[deformablepsroipool::kBox], out_data[deformablepsroipool::kTopCount] };
Expand All @@ -292,8 +291,9 @@ class DeformablePSROIPoolingProp : public OperatorProperty {
return NULL;
}

Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
std::vector<int> *in_type) const override;
Operator* CreateOperatorEx(Context ctx,
mxnet::ShapeVector *in_shape,
std::vector<int> *in_type) const override;


private:
Expand Down
Loading

0 comments on commit ca6920e

Please sign in to comment.