forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dc027a4
commit 17c6819
Showing
3 changed files
with
954 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file Correlation.cu | ||
* \brief Correlation pooling operator | ||
* \author Xu Dong | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_ROI_POOLING_INL_H_ | ||
#define MXNET_OPERATOR_ROI_POOLING_INL_H_ | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/parameter.h> | ||
#include <mxnet/operator.h> | ||
#include <map> | ||
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include "./mshadow_op.h" | ||
#include "./operator_common.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
// Declare enumeration of input order to make code more intuitive. | ||
// These enums are only visible within this header | ||
namespace Correlation { | ||
enum CorrelationOpInputs{kData1, kData2}; | ||
enum CorrelationOpOutputs{kOut,kTemp1,kTemp2}; | ||
} // namespace correlation | ||
|
||
struct CorrelationParam : public dmlc::Parameter<CorrelationParam> { | ||
uint32_t max_displacement; | ||
uint32_t kernel_size; | ||
uint32_t pad_size; | ||
uint32_t stride1; | ||
uint32_t stride2; | ||
bool is_multiply; | ||
DMLC_DECLARE_PARAMETER(CorrelationParam) { | ||
DMLC_DECLARE_FIELD(kernel_size).set_default(1).describe("kernel size for Correlation"); | ||
DMLC_DECLARE_FIELD(max_displacement).set_default(1).describe("Max displacement of Correlation "); | ||
DMLC_DECLARE_FIELD(stride1).set_default(1).describe("stride between Correlation"); | ||
DMLC_DECLARE_FIELD(stride2).set_default(1).describe("stride within neighbourhood"); | ||
DMLC_DECLARE_FIELD(pad_size).set_default(0).describe("pad for Correlation"); | ||
DMLC_DECLARE_FIELD(is_multiply).set_default(true).describe("operation type is either multiplication or subduction"); | ||
} | ||
|
||
}; | ||
|
||
template<typename xpu> | ||
class CorrelationOp : public Operator { | ||
public: | ||
explicit CorrelationOp(CorrelationParam param) { | ||
this->param_ = param; | ||
} | ||
|
||
virtual void Forward( const OpContext &ctx, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<TBlob> &aux_args) { | ||
using namespace mshadow; | ||
|
||
CHECK_EQ(in_data.size(), 2); | ||
CHECK_EQ(out_data.size(), 3); | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
|
||
Tensor<xpu, 4> data1 = in_data[Correlation::kData1].get<xpu, 4, real_t>(s); | ||
Tensor<xpu, 4> data2 = in_data[Correlation::kData2].get<xpu, 4, real_t>(s); | ||
Tensor<xpu, 4> out = out_data[Correlation::kOut].get<xpu, 4, real_t>(s); | ||
Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu,4,real_t>(s); | ||
Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu,4,real_t>(s); | ||
|
||
tmp1 = 0.0f; | ||
tmp2 = 0.0f; | ||
|
||
CHECK_EQ(data1.CheckContiguous(), true); | ||
CHECK_EQ(data2.CheckContiguous(), true); | ||
CHECK_EQ(out.CheckContiguous(), true); | ||
CHECK_EQ(tmp1.CheckContiguous(), true); | ||
CHECK_EQ(tmp2.CheckContiguous(), true); | ||
|
||
paddedbottomheight = data1.shape_[2] + 2*param_.pad_size; | ||
paddedbottomwidth = data1.shape_[3] + 2*param_.pad_size; | ||
kernel_radius_ = (param_.kernel_size -1 )/2; | ||
border_size_ = param_.max_displacement + kernel_radius_; | ||
stride1 = param_.stride1; | ||
stride2 = param_.stride2; | ||
top_width_ = ceil((float)(paddedbottomwidth - border_size_*2) / (float)stride1); | ||
top_height_ = ceil((float)(paddedbottomheight - border_size_*2) / (float)stride1); | ||
neighborhood_grid_radius_ = param_.max_displacement / stride2; | ||
neighborhood_grid_width_ = neighborhood_grid_radius_ * 2 + 1; | ||
top_channels_ = neighborhood_grid_width_ * neighborhood_grid_width_; | ||
|
||
num = data1.shape_[0]; | ||
channels = data1.shape_[1]; | ||
height = data1.shape_[2]; | ||
width = data1.shape_[3]; | ||
|
||
CorrelationForward(out, data1, data2, tmp1,tmp2,top_channels_,top_height_,top_width_, | ||
param_.pad_size,param_.is_multiply,param_.max_displacement,param_.kernel_size,neighborhood_grid_radius_, | ||
neighborhood_grid_width_,kernel_radius_,param_.stride1,param_.stride2); | ||
} | ||
|
||
virtual void Backward(const OpContext &ctx, | ||
const std::vector<TBlob> &out_grad, | ||
const std::vector<TBlob> &in_data, | ||
const std::vector<TBlob> &out_data, | ||
const std::vector<OpReqType> &req, | ||
const std::vector<TBlob> &in_grad, | ||
const std::vector<TBlob> &aux_args) { | ||
using namespace mshadow; | ||
Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
Tensor<xpu, 4> grad_data1 = in_grad[Correlation::kData1].get<xpu,4,real_t>(s); | ||
Tensor<xpu, 4> grad_data2 = in_grad[Correlation::kData2].get<xpu,4,real_t>(s); | ||
Tensor<xpu, 4> out_g = out_grad[Correlation::kOut].get<xpu, 4, real_t>(s); | ||
Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu,4,real_t>(s); | ||
Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu,4,real_t>(s); | ||
|
||
CHECK_EQ(grad_data1.CheckContiguous(), true); | ||
CHECK_EQ(grad_data2.CheckContiguous(), true); | ||
CHECK_EQ(out_g.CheckContiguous(), true); | ||
CHECK_EQ(tmp1.CheckContiguous(), true); | ||
CHECK_EQ(tmp2.CheckContiguous(), true); | ||
|
||
CorrelationBackward(out_g,grad_data1,grad_data2,tmp1,tmp2,top_channels_,top_height_,top_width_,param_.pad_size,param_.is_multiply, | ||
param_.max_displacement,param_.kernel_size,neighborhood_grid_radius_,neighborhood_grid_width_,kernel_radius_,param_.stride1,param_.stride2,num,channels,height, width); | ||
} | ||
private: | ||
CorrelationParam param_; | ||
int paddedbottomheight; | ||
int paddedbottomwidth; | ||
uint32_t kernel_radius_; | ||
uint32_t border_size_; | ||
uint32_t stride1 ; | ||
uint32_t stride2 ; | ||
uint32_t top_width_; | ||
uint32_t top_height_; | ||
uint32_t neighborhood_grid_radius_; | ||
uint32_t neighborhood_grid_width_ ; | ||
uint32_t top_channels_ ; | ||
int num; | ||
int channels; | ||
int height; | ||
int width; | ||
}; // class CorrelationOp | ||
|
||
|
||
// Decalre Factory function | ||
template<typename xpu> | ||
Operator* CreateOp(CorrelationParam param); | ||
|
||
#if DMLC_USE_CXX11 | ||
class CorrelationProp : public OperatorProperty { | ||
public: | ||
std::vector<std::string> ListArguments() const override { | ||
return {"data1", "data2"}; | ||
} | ||
|
||
std::vector<std::string> ListOutputs() const override { | ||
return {"output","tmp1","tmp2"}; | ||
} | ||
|
||
int NumOutputs() const override { | ||
return 3; | ||
} | ||
|
||
int NumVisibleOutputs() const override { | ||
return 1; | ||
} | ||
|
||
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { | ||
param_.Init(kwargs); | ||
} | ||
|
||
std::map<std::string, std::string> GetParams() const override { | ||
return param_.__DICT__(); | ||
} | ||
bool InferShape(std::vector<TShape> *in_shape, | ||
std::vector<TShape> *out_shape, | ||
std::vector<TShape> *aux_shape) const override { | ||
|
||
using namespace mshadow; | ||
CHECK_EQ(in_shape->size(), 2) << "Input:[data1, data2]"; | ||
TShape dshape1 = in_shape->at(Correlation::kData1); | ||
TShape dshape2 = in_shape->at(Correlation::kData2); | ||
CHECK_EQ(dshape1.ndim(), 4) << "data should be a 4D tensor"; | ||
CHECK_EQ(dshape2.ndim(), 4) << "data should be a 4D tensor"; | ||
|
||
int paddedbottomheight; | ||
int paddedbottomwidth; | ||
uint32_t kernel_radius_; | ||
uint32_t stride1 ; | ||
uint32_t stride2 ; | ||
uint32_t top_width_; | ||
uint32_t top_height_; | ||
uint32_t neighborhood_grid_radius_; | ||
uint32_t neighborhood_grid_width_ ; | ||
uint32_t top_channels_ ; | ||
uint32_t border_size_; | ||
|
||
paddedbottomheight = dshape1[2] + 2*param_.pad_size; | ||
paddedbottomwidth = dshape1[3] + 2*param_.pad_size; | ||
kernel_radius_ = (param_.kernel_size -1 )/2; | ||
border_size_ = param_.max_displacement + kernel_radius_; | ||
stride1 = param_.stride1; | ||
stride2 = param_.stride2; | ||
top_width_ = ceil((float)(paddedbottomwidth - border_size_*2) / (float)stride1); | ||
top_height_ = ceil((float)(paddedbottomheight - border_size_*2) / (float)stride1); | ||
neighborhood_grid_radius_ = param_.max_displacement / stride2; | ||
neighborhood_grid_width_ = neighborhood_grid_radius_ * 2 + 1; | ||
top_channels_ = neighborhood_grid_width_ * neighborhood_grid_width_; | ||
|
||
CHECK_GE(top_width_, 1) << "Correlation cannot be done with current settings. Neighborhood and kernel don't fit in blob"; | ||
CHECK_GE(top_height_, 1) << "Correlation cannot be done with current settings. Neighborhood and kernel don't fit in blob"; | ||
|
||
out_shape->clear(); | ||
out_shape->push_back(Shape4(dshape1[0],top_channels_,top_height_,top_width_)); | ||
out_shape->push_back(Shape4(dshape1[0],paddedbottomheight, paddedbottomwidth,dshape1[1])); | ||
out_shape->push_back(Shape4(dshape1[0],paddedbottomheight, paddedbottomwidth,dshape1[1])); | ||
return true; | ||
} | ||
|
||
OperatorProperty* Copy() const override { | ||
CorrelationProp* Correlation_sym = new CorrelationProp(); | ||
Correlation_sym->param_ = this->param_; | ||
return Correlation_sym; | ||
} | ||
|
||
std::string TypeString() const override { | ||
return "Correlation"; | ||
} | ||
|
||
// decalre dependency and inplace optimization options | ||
std::vector<int> DeclareBackwardDependency( | ||
const std::vector<int> &out_grad, | ||
const std::vector<int> &in_data, | ||
const std::vector<int> &out_data) const override | ||
{ | ||
return {out_grad[Correlation::kOut], out_data[Correlation::kTemp1], out_data[Correlation::kTemp2]}; | ||
} | ||
|
||
Operator* CreateOperator(Context ctx) const override; | ||
|
||
private: | ||
CorrelationParam param_; | ||
}; // class CorrelationProp | ||
#endif | ||
} // namespace op | ||
} // namespace mxnet | ||
#endif // MXNET_OPERATOR_Correlation_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file correlation.cc | ||
* \brief | ||
* \author Xu Dong | ||
*/ | ||
|
||
#include "./correlation-inl.h" | ||
|
||
namespace mshadow { | ||
|
||
template<typename Dtype> | ||
inline void CorrelationForward( const Tensor<cpu, 4, Dtype> &out, | ||
const Tensor<cpu, 4, Dtype> &data1, | ||
const Tensor<cpu, 4, Dtype> &data2, | ||
const Tensor<cpu, 4, Dtype> &tmp1, | ||
const Tensor<cpu, 4, Dtype> &tmp2, | ||
int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, | ||
int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, | ||
int kernel_radius_,int stride1_,int stride2_ | ||
) { | ||
return ; | ||
} | ||
|
||
template<typename Dtype> | ||
inline void CorrelationBackward(const Tensor<cpu, 4, Dtype> &out_grad, | ||
const Tensor<cpu, 4, Dtype> &in_grad1, | ||
const Tensor<cpu, 4, Dtype> &in_grad2, | ||
const Tensor<cpu, 4, Dtype> &tmp1, | ||
const Tensor<cpu, 4, Dtype> &tmp2, | ||
int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, | ||
int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, | ||
int kernel_radius_,int stride1_,int stride2_,int num, int channels,int height, int width | ||
) { | ||
|
||
return ; | ||
} | ||
} // namespace mshadow | ||
|
||
|
||
namespace mxnet { | ||
namespace op { | ||
template<> | ||
Operator *CreateOp<cpu>(CorrelationParam param) { | ||
return new CorrelationOp<cpu>(param); | ||
} | ||
|
||
Operator* CorrelationProp::CreateOperator(Context ctx) const { | ||
DO_BIND_DISPATCH(CreateOp, param_); | ||
} | ||
|
||
DMLC_REGISTER_PARAMETER(CorrelationParam); | ||
|
||
MXNET_REGISTER_OP_PROPERTY(Correlation, CorrelationProp) | ||
.describe("Apply correlation to inputs") | ||
.add_argument("data1", "Symbol", "Input data to the correlation.") | ||
.add_argument("data2", "Symbol", "Input data to the correlation.") | ||
.add_arguments(CorrelationParam::__FIELDS__()); | ||
|
||
} // namespace op | ||
} // namespace mxnet |
Oops, something went wrong.