diff --git a/src/operator/correlation-inl.h b/src/operator/correlation-inl.h new file mode 100644 index 000000000000..e6c5c0b39f40 --- /dev/null +++ b/src/operator/correlation-inl.h @@ -0,0 +1,250 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file Correlation.cu + * \brief Correlation pooling operator + * \author Xu Dong +*/ + +#ifndef MXNET_OPERATOR_ROI_POOLING_INL_H_ +#define MXNET_OPERATOR_ROI_POOLING_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./mshadow_op.h" +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +// Declare enumeration of input order to make code more intuitive. +// These enums are only visible within this header +namespace Correlation { +enum CorrelationOpInputs{kData1, kData2}; +enum CorrelationOpOutputs{kOut,kTemp1,kTemp2}; +} // namespace correlation + +struct CorrelationParam : public dmlc::Parameter { + uint32_t max_displacement; + uint32_t kernel_size; + uint32_t pad_size; + uint32_t stride1; + uint32_t stride2; + bool is_multiply; + DMLC_DECLARE_PARAMETER(CorrelationParam) { + DMLC_DECLARE_FIELD(kernel_size).set_default(1).describe("kernel size for Correlation"); + DMLC_DECLARE_FIELD(max_displacement).set_default(1).describe("Max displacement of Correlation "); + DMLC_DECLARE_FIELD(stride1).set_default(1).describe("stride between Correlation"); + DMLC_DECLARE_FIELD(stride2).set_default(1).describe("stride within neighbourhood"); + DMLC_DECLARE_FIELD(pad_size).set_default(0).describe("pad for Correlation"); + DMLC_DECLARE_FIELD(is_multiply).set_default(true).describe("operation type is either multiplication or subduction"); + } + +}; + +template +class CorrelationOp : public Operator { + public: + explicit CorrelationOp(CorrelationParam param) { + this->param_ = param; + } + + virtual void Forward( const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + + CHECK_EQ(in_data.size(), 2); + CHECK_EQ(out_data.size(), 3); + Stream *s = ctx.get_stream(); + + Tensor data1 = in_data[Correlation::kData1].get(s); + Tensor data2 = in_data[Correlation::kData2].get(s); + Tensor out = out_data[Correlation::kOut].get(s); + Tensor tmp1 = out_data[Correlation::kTemp1].get(s); + Tensor tmp2 = out_data[Correlation::kTemp2].get(s); + + tmp1 = 0.0f; + tmp2 = 0.0f; + + CHECK_EQ(data1.CheckContiguous(), true); + CHECK_EQ(data2.CheckContiguous(), true); + CHECK_EQ(out.CheckContiguous(), true); + CHECK_EQ(tmp1.CheckContiguous(), true); + CHECK_EQ(tmp2.CheckContiguous(), true); + + paddedbottomheight = data1.shape_[2] + 2*param_.pad_size; + paddedbottomwidth = data1.shape_[3] + 2*param_.pad_size; + kernel_radius_ = (param_.kernel_size -1 )/2; + border_size_ = param_.max_displacement + kernel_radius_; + stride1 = param_.stride1; + stride2 = param_.stride2; + top_width_ = ceil((float)(paddedbottomwidth - border_size_*2) / (float)stride1); + top_height_ = ceil((float)(paddedbottomheight - border_size_*2) / (float)stride1); + neighborhood_grid_radius_ = param_.max_displacement / stride2; + neighborhood_grid_width_ = neighborhood_grid_radius_ * 2 + 1; + top_channels_ = neighborhood_grid_width_ * neighborhood_grid_width_; + + num = data1.shape_[0]; + channels = data1.shape_[1]; + height = data1.shape_[2]; + width = data1.shape_[3]; + + CorrelationForward(out, data1, data2, tmp1,tmp2,top_channels_,top_height_,top_width_, + param_.pad_size,param_.is_multiply,param_.max_displacement,param_.kernel_size,neighborhood_grid_radius_, + neighborhood_grid_width_,kernel_radius_,param_.stride1,param_.stride2); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + Stream *s = ctx.get_stream(); + Tensor grad_data1 = in_grad[Correlation::kData1].get(s); + Tensor grad_data2 = in_grad[Correlation::kData2].get(s); + Tensor out_g = out_grad[Correlation::kOut].get(s); + Tensor tmp1 = out_data[Correlation::kTemp1].get(s); + Tensor tmp2 = out_data[Correlation::kTemp2].get(s); + + CHECK_EQ(grad_data1.CheckContiguous(), true); + CHECK_EQ(grad_data2.CheckContiguous(), true); + CHECK_EQ(out_g.CheckContiguous(), true); + CHECK_EQ(tmp1.CheckContiguous(), true); + CHECK_EQ(tmp2.CheckContiguous(), true); + + CorrelationBackward(out_g,grad_data1,grad_data2,tmp1,tmp2,top_channels_,top_height_,top_width_,param_.pad_size,param_.is_multiply, + param_.max_displacement,param_.kernel_size,neighborhood_grid_radius_,neighborhood_grid_width_,kernel_radius_,param_.stride1,param_.stride2,num,channels,height, width); + } + private: + CorrelationParam param_; + int paddedbottomheight; + int paddedbottomwidth; + uint32_t kernel_radius_; + uint32_t border_size_; + uint32_t stride1 ; + uint32_t stride2 ; + uint32_t top_width_; + uint32_t top_height_; + uint32_t neighborhood_grid_radius_; + uint32_t neighborhood_grid_width_ ; + uint32_t top_channels_ ; + int num; + int channels; + int height; + int width; +}; // class CorrelationOp + + +// Decalre Factory function +template +Operator* CreateOp(CorrelationParam param); + +#if DMLC_USE_CXX11 +class CorrelationProp : public OperatorProperty { + public: + std::vector ListArguments() const override { + return {"data1", "data2"}; + } + + std::vector ListOutputs() const override { + return {"output","tmp1","tmp2"}; + } + + int NumOutputs() const override { + return 3; + } + + int NumVisibleOutputs() const override { + return 1; + } + + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + + using namespace mshadow; + CHECK_EQ(in_shape->size(), 2) << "Input:[data1, data2]"; + TShape dshape1 = in_shape->at(Correlation::kData1); + TShape dshape2 = in_shape->at(Correlation::kData2); + CHECK_EQ(dshape1.ndim(), 4) << "data should be a 4D tensor"; + CHECK_EQ(dshape2.ndim(), 4) << "data should be a 4D tensor"; + + int paddedbottomheight; + int paddedbottomwidth; + uint32_t kernel_radius_; + uint32_t stride1 ; + uint32_t stride2 ; + uint32_t top_width_; + uint32_t top_height_; + uint32_t neighborhood_grid_radius_; + uint32_t neighborhood_grid_width_ ; + uint32_t top_channels_ ; + uint32_t border_size_; + + paddedbottomheight = dshape1[2] + 2*param_.pad_size; + paddedbottomwidth = dshape1[3] + 2*param_.pad_size; + kernel_radius_ = (param_.kernel_size -1 )/2; + border_size_ = param_.max_displacement + kernel_radius_; + stride1 = param_.stride1; + stride2 = param_.stride2; + top_width_ = ceil((float)(paddedbottomwidth - border_size_*2) / (float)stride1); + top_height_ = ceil((float)(paddedbottomheight - border_size_*2) / (float)stride1); + neighborhood_grid_radius_ = param_.max_displacement / stride2; + neighborhood_grid_width_ = neighborhood_grid_radius_ * 2 + 1; + top_channels_ = neighborhood_grid_width_ * neighborhood_grid_width_; + + CHECK_GE(top_width_, 1) << "Correlation cannot be done with current settings. Neighborhood and kernel don't fit in blob"; + CHECK_GE(top_height_, 1) << "Correlation cannot be done with current settings. Neighborhood and kernel don't fit in blob"; + + out_shape->clear(); + out_shape->push_back(Shape4(dshape1[0],top_channels_,top_height_,top_width_)); + out_shape->push_back(Shape4(dshape1[0],paddedbottomheight, paddedbottomwidth,dshape1[1])); + out_shape->push_back(Shape4(dshape1[0],paddedbottomheight, paddedbottomwidth,dshape1[1])); + return true; + } + + OperatorProperty* Copy() const override { + CorrelationProp* Correlation_sym = new CorrelationProp(); + Correlation_sym->param_ = this->param_; + return Correlation_sym; + } + + std::string TypeString() const override { + return "Correlation"; + } + + // decalre dependency and inplace optimization options + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override + { + return {out_grad[Correlation::kOut], out_data[Correlation::kTemp1], out_data[Correlation::kTemp2]}; + } + + Operator* CreateOperator(Context ctx) const override; + + private: + CorrelationParam param_; +}; // class CorrelationProp +#endif +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_Correlation_INL_H_ diff --git a/src/operator/correlation.cc b/src/operator/correlation.cc new file mode 100644 index 000000000000..70e2ac35a328 --- /dev/null +++ b/src/operator/correlation.cc @@ -0,0 +1,61 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file correlation.cc + * \brief + * \author Xu Dong +*/ + +#include "./correlation-inl.h" + +namespace mshadow { + +template +inline void CorrelationForward( const Tensor &out, + const Tensor &data1, + const Tensor &data2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_ + ) { + return ; +} + +template +inline void CorrelationBackward(const Tensor &out_grad, + const Tensor &in_grad1, + const Tensor &in_grad2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_,int num, int channels,int height, int width + ) { + + return ; +} +} // namespace mshadow + + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(CorrelationParam param) { + return new CorrelationOp(param); +} + +Operator* CorrelationProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(CorrelationParam); + +MXNET_REGISTER_OP_PROPERTY(Correlation, CorrelationProp) +.describe("Apply correlation to inputs") +.add_argument("data1", "Symbol", "Input data to the correlation.") +.add_argument("data2", "Symbol", "Input data to the correlation.") +.add_arguments(CorrelationParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/correlation.cu b/src/operator/correlation.cu new file mode 100644 index 000000000000..884afb3b0ec3 --- /dev/null +++ b/src/operator/correlation.cu @@ -0,0 +1,643 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file Correation.cu + * \brief Correlation operator + * \author Xu Dong +*/ +#include "./correlation-inl.h" +#include +#include +#include +#include + +#define ROUND_OFF 50000 +#define WARPS_PER_BLOCK 1 +#define THREADS_PER_WARP 32 + +#define CORRELATION_CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ + } while (0) + +#define CUDA_KERNEL_LOOP(i, n) \ +for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +namespace mshadow { +namespace cuda { + +// == Correlation Kernel +template +__global__ void CorrelateData(const int nthreads, int num, int topwidth, int topheight, int topchannels, int topcount, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int kernel_size, int stride1, int stride2, + int bottomwidth, int bottomheight, int bottomchannels, + const Dtype *bottom0, const Dtype *bottom1, Dtype *top) +{ + extern __shared__ char patch_data_char[]; + + Dtype *patch_data = (Dtype *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x*stride1 + max_displacement; + int y1 = blockIdx.y*stride1 + max_displacement; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for(int j = 0; j < kernel_size; j++) { // HEIGHT + for(int i = 0; i < kernel_size; i++) { // WIDTH + int ji_off = ((j * kernel_size) + i) * bottomchannels; + for(int ch = ch_off; ch < bottomchannels; ch += (THREADS_PER_WARP * WARPS_PER_BLOCK) ) { // CHANNELS + int idx1 = ((item * bottomheight + y1+j) * bottomwidth + x1+i) * bottomchannels + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = bottom0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ Dtype sum[THREADS_PER_WARP * WARPS_PER_BLOCK]; + + // Compute correlation + for(int top_channel = 0; top_channel < topchannels; top_channel++) { + sum[ch_off] = 0; + + int s2o = (top_channel % neighborhood_grid_width - neighborhood_grid_radius) * stride2; + int s2p = (top_channel / neighborhood_grid_width - neighborhood_grid_radius) * stride2; + + for(int j = 0; j < kernel_size; j++) { // HEIGHT + for(int i = 0; i < kernel_size; i++) { // WIDTH + int ji_off = ((j * kernel_size) + i) * bottomchannels; + for(int ch = ch_off; ch < bottomchannels; ch += (THREADS_PER_WARP * WARPS_PER_BLOCK)) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * bottomheight + y2+j) * bottomwidth + x2+i) * bottomchannels + ch; + + sum[ch_off] += patch_data[idxPatchData] * bottom1[idx2]; + } + } + } + + __syncthreads(); + + if(ch_off == 0) { + Dtype total_sum = 0; + for(int idx = 0; idx < THREADS_PER_WARP * WARPS_PER_BLOCK; idx++) { + total_sum += sum[idx]; + } + const int sumelems = kernel_size*kernel_size*bottomchannels; + const int index = ((top_channel*topheight + blockIdx.y)*topwidth)+blockIdx.x; + top[index + item*topcount] = total_sum / (float)sumelems; + } // Aggregate result of different threads + } +} + +// == Correlation Backward Pass Kernel (For data1) +template +__global__ void CorrelateDataBackward0(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, + int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, + Dtype *bottom0diff, const Dtype *bottom1, const Dtype *topdiff) +{ + CUDA_KERNEL_LOOP(index, nthreads) { + int n = index % bottomchannels; //channels + int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos + int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos + + //Get X,Y ranges and clamp + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 + int ymin = (m - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 + + // Same here: + int xmax = (l - max_displacement + round_off_s1) / stride1 - round_off; // floor (l - max_displacement) / stride1 + int ymax = (m - max_displacement + round_off_s1) / stride1 - round_off; // floor (m - max_displacement) / stride1 + + + Dtype sum = 0; + if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1)) + { + xmin = max(0,xmin); + xmax = min(topwidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(topheight-1,ymax); + + for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + + // Get bottom1 data: + int s2o = stride2 * o; + int s2p = stride2 * p; + int idxbot1 = ((item * pbottomheight + (m+s2p)) * pbottomwidth + (l+s2o)) * bottomchannels + n; + Dtype bot1tmp = bottom1[idxbot1]; // bottom1[l+s2o,m+s2p,n] + + // Index offset for topdiff in following loops: + int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] + int idxopoffset = (item * topchannels + op); + + for(int y = ymin; y <= ymax; y++) { + for(int x = xmin; x <= xmax; x++) { + int idxtopdiff = (idxopoffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] + sum += topdiff[idxtopdiff] * bot1tmp; + } + } + } + } + } + const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; + const int bot0index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size); + bottom0diff[bot0index + item*bottomcount] = sum / (float)sumelems; + } + +} + +// == Correlation Backward Pass Kernel (For Blob 1) +template +__global__ void CorrelateDataBackward1(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, + int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, + const Dtype *bottom0, Dtype *bottom1diff, const Dtype *topdiff) +{ + CUDA_KERNEL_LOOP(index, nthreads) { + //int l = index % bottomwidth + pad_size; //w-pos + //int m = (index / bottomwidth) % bottomheight + pad_size; //h-pos + //int n = (index / bottomwidth / bottomheight) % bottomchannels; //channels + int n = index % bottomchannels; //channels + int l = (index / bottomchannels) % bottomwidth + pad_size; //w-pos + int m = (index / bottomchannels / bottomwidth) % bottomheight + pad_size; //h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + Dtype sum = 0; + for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + + int s2o = stride2 * o; + int s2p = stride2 * p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 2*kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1 + int ymin = (m - 2*kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1 + + // Same here: + int xmax = (l - max_displacement - s2o + round_off_s1) / stride1 - round_off; // floor (l - max_displacement - s2o) / stride1 + int ymax = (m - max_displacement - s2p + round_off_s1) / stride1 - round_off; // floor (m - max_displacement - s2p) / stride1 + + if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1)) + { + xmin = max(0,xmin); + xmax = min(topwidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(topheight-1,ymax); + + // Get bottom0 data: + int idxbot0 = ((item * pbottomheight + (m-s2p)) * pbottomwidth + (l-s2o)) * bottomchannels + n; + Dtype bot0tmp = bottom0[idxbot0]; // bottom1[l+s2o,m+s2p,n] + + // Index offset for topdiff in following loops: + int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] + int idxOpOffset = (item * topchannels + op); + + for(int y = ymin; y <= ymax; y++) { + for(int x = xmin; x <= xmax; x++) { + int idxtopdiff = (idxOpOffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] + sum += topdiff[idxtopdiff] * bot0tmp; + } + } + } + } + } + const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; + const int bot1index = ((n * bottomheight) + (m-pad_size)) * bottomwidth + (l-pad_size); + bottom1diff[bot1index + item*bottomcount] = sum / (float)sumelems; + } + +} + +// == Correlation Kernel Subtraction +template +__global__ void CorrelateDataSubtract(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, int topcount, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, + int bottomwidth, int bottomheight, int bottomchannels, + const Dtype *bottom0, const Dtype *bottom1, Dtype *top) +{ + CUDA_KERNEL_LOOP(index, nthreads) { + int x = index % topwidth; //w-pos + int y = (index / topwidth) % topheight; //h-pos + int c = (index / topwidth / topheight) % topchannels; //channels + + // Offset of patch in image 2 + int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * stride2; + int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * stride2; + + // First (upper left) position of kernel center in current neighborhood in image 1 + int x1 = x*stride1 + kernel_radius + max_displacement; + int y1 = y*stride1 + kernel_radius + max_displacement; + + // Iterate through 3D patch + Dtype sum = 0; + for(int j = -kernel_radius; j <= kernel_radius; j++) { // HEIGHT + for(int i = -kernel_radius; i <= kernel_radius; i++) { // WIDTH + for(int l = 0; l < bottomchannels; l++) { // CHANNELS + // Calculate position in image 2 + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + // Indices in bottom data: (CH=l,W=x2,H=y2,N) + int idx1 = ((item * bottomheight + y1+j) * bottomwidth + x1+i) * bottomchannels + l; + int idx2 = ((item * bottomheight + y2+j) * bottomwidth + x2+i) * bottomchannels + l; + + // Do the correlation: + sum += fabsf(bottom0[idx1] - bottom1[idx2]); + } + } + } + const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; + top[index + item*topcount] = sum / (float)sumelems; + } + +} + + +// == Correlation Backward Pass Kernel (For Blob 0) +template +__global__ void CorrelateDataBackward0Subtract(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, + int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, + Dtype *bottom0diff, const Dtype *bottom0, const Dtype *bottom1, const Dtype *topdiff) +{ + CUDA_KERNEL_LOOP(index, nthreads) { + int l = index % bottomwidth + pad_size; //w-pos + int m = (index / bottomwidth) % bottomheight + pad_size; //h-pos + int n = (index / bottomwidth / bottomheight) % bottomchannels; //channels + + //Get X,Y ranges and clamp + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 + int ymin = (m - 2*kernel_radius - max_displacement + round_off_s1 - 1) / stride1 + 1 - round_off; // ceil (l - 2*kernel_radius - max_displacement) / stride1 + + // Same here: + int xmax = (l - max_displacement + round_off_s1) / stride1 - round_off; // floor (l - max_displacement) / stride1 + int ymax = (m - max_displacement + round_off_s1) / stride1 - round_off; // floor (m - max_displacement) / stride1 + + + Dtype sum = 0; + if(xmax>=0 && ymax>=0 && (xmin<=topwidth-1) && (ymin<=topheight-1)) + { + xmin = max(0,xmin); + xmax = min(topwidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(topheight-1,ymax); + + for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + + // Get bottom1 data: + int s2o = stride2 * o; + int s2p = stride2 * p; + int idxbot = ((item * pbottomheight + (m+s2p)) * pbottomwidth + (l+s2o)) * bottomchannels + n; + Dtype bot0tmp = bottom0[idxbot]; // bottom0[l+s2o,m+s2p,n] + Dtype bot1tmp = bottom1[idxbot]; // bottom1[l+s2o,m+s2p,n] + Dtype sign = (bot0tmp >= bot1tmp) ? Dtype(1.0) : Dtype(-1.0); + + // Index offset for topdiff in following loops: + int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] + int idxopoffset = (item * topchannels + op); + + for(int y = ymin; y <= ymax; y++) { + for(int x = xmin; x <= xmax; x++) { + int idxtopdiff = (idxopoffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] + sum += topdiff[idxtopdiff] * sign; + } + } + } + } + } + const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; + bottom0diff[index + item*bottomcount] = sum / (float)sumelems; + } + +} + + +// == Correlation Backward Pass Kernel (For Blob 1) +template +__global__ void CorrelateDataBackward1Subtract(const int nthreads, int num, int item, int topwidth, int topheight, int topchannels, + int max_displacement, int neighborhood_grid_radius, int neighborhood_grid_width, int kernel_radius, int stride1, int stride2, + int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight, int bottomchannels, int bottomcount, int pad_size, + const Dtype *bottom0, const Dtype *bottom1, Dtype *bottom1diff, const Dtype *topdiff) +{ + CUDA_KERNEL_LOOP(index, nthreads) { + int l = index % bottomwidth + pad_size; //w-pos + int m = (index / bottomwidth) % bottomheight + pad_size; //h-pos + int n = (index / bottomwidth / bottomheight) % bottomchannels; //channels + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = stride1 * round_off; + + Dtype sum = 0; + for(int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; p++) { + for(int o = -neighborhood_grid_radius; o <= neighborhood_grid_radius; o++) { + + int s2o = stride2 * o; + int s2p = stride2 * p; + + int xmin = (l - 2*kernel_radius - max_displacement - s2o + round_off_s1 - 1) / stride1 + 1 - round_off; + int ymin = (m - 2*kernel_radius - max_displacement - s2p + round_off_s1 - 1) / stride1 + 1 - round_off; + + int xmax = (l - max_displacement - s2o + round_off_s1) / stride1 - round_off; + int ymax = (m - max_displacement - s2p + round_off_s1) / stride1 - round_off; + if((xmax>=0) && (ymax>=0) && (xmin<=topwidth-1) && (ymin<=topheight-1)) + { + xmin = max(0,xmin); + xmax = min(topwidth-1,xmax); + + ymin = max(0,ymin); + ymax = min(topheight-1,ymax); + + // Get bottom0 data: + int idxbot = ((item * pbottomheight + (m-s2p)) * pbottomwidth + (l-s2o)) * bottomchannels + n; + // bottom0[l+s2o,m+s2p,n] + Dtype bot0tmp = bottom0[idxbot]; + Dtype bot1tmp = bottom1[idxbot]; + Dtype sign = (bot0tmp >= bot1tmp) ? Dtype(-1.0) : Dtype(1.0); + + // Index offset for topdiff in following loops: + int op = (p+neighborhood_grid_radius) * neighborhood_grid_width + (o+neighborhood_grid_radius); // index [o,p] + int idxOpOffset = (item * topchannels + op); + + for(int y = ymin; y <= ymax; y++) { + for(int x = xmin; x <= xmax; x++) { + int idxtopdiff = (idxOpOffset * topheight + y) * topwidth + x; // topdiff[x,y,o,p] + sum += topdiff[idxtopdiff] * sign; + } + } + } + } + } + const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels; + bottom1diff[index + item*bottomcount] = sum / (float)sumelems; + } + +} +// == Forward +// == Dimension rearrangement Kernel + +template +__global__ void blob_rearrange_kernel2(const Dtype* in, Dtype* out, int num, int channels, int width, int height, int widthheight, int padding, int pwidthheight) +{ + // change shape from [batchsize,channel,y,x] to [batchsize,y,x,channel] + int xy = blockIdx.x*blockDim.x + threadIdx.x; + if(xy>=widthheight) + return; + + int ch = blockIdx.y; + int n = blockIdx.z; + + Dtype value=in[(n*channels+ch)*widthheight+xy]; + + __syncthreads(); + + int xpad = (xy % width + padding); + int ypad = (xy / width + padding); + int xypad = ypad * (width+2*padding) + xpad; + + out[(n*pwidthheight+xypad)*channels + ch] = value; +} +template +void Forward_gpu( + const Tensor &out, + const Tensor &data1, + const Tensor &data2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_,cudaStream_t stream,cudaStream_t stream_tmp1,cudaStream_t stream_tmp2) +{ + const Dtype *bottom_data1 = data1.dptr_; + const Dtype *bottom_data2 = data2.dptr_; + Dtype *rbot1 = tmp1.dptr_; + Dtype *rbot2 = tmp2.dptr_; + Dtype *top = out.dptr_; + const int bnum = data1.size(0); + const int bchannels = data1.size(1); + const int bheight = data1.size(2); + const int bwidth = data1.size(3); + const int bwidthheight = bwidth * bheight; + const int topcount = top_width_ * top_height_ * top_channels_; + + dim3 threadsPerBlock(THREADS_PER_WARP * WARPS_PER_BLOCK); + int threads_per_block=16; + dim3 totalBlocksRearr((bwidthheight-1)/threads_per_block+1, bchannels, bnum); + const int pwidthheight = (bwidth + 2 * pad_size_) * (bheight + 2 * pad_size_); + + blob_rearrange_kernel2<<>>(bottom_data1,rbot1,bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight); + blob_rearrange_kernel2<<>>(bottom_data2,rbot2,bnum,bchannels,bwidth,bheight,bwidthheight,pad_size_,pwidthheight); + + const int num = bnum; + const int channels = bchannels; + const int height = bheight + 2*pad_size_; + const int width = bwidth + 2*pad_size_; + + const int shared_memory_per_block = (kernel_size_*kernel_size_)*bchannels; + + if(is_multiply == true) { + // CorrelationLayer + int topThreadCount = topcount; + + dim3 totalBlocksCorr(top_width_, top_height_, num); + + CorrelateData<<>>( + topThreadCount, + num, top_width_, top_height_, top_channels_, topcount, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, kernel_size_, + stride1_, stride2_, + width, height, channels, + rbot1, rbot2, top + ); + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + + } else { + // CorrelationLayer + for(int n = 0; n < num; n++) { + + int topThreadCount = topcount; + const int gridSize = (topThreadCount + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + CorrelateDataSubtract<<>>( + topThreadCount, + num, n, top_width_, top_height_, top_channels_, topcount, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, + stride1_, stride2_,width, height, channels,rbot1, rbot2, top ); + + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + } + } +} + +template +void Backward_gpu( + const Tensor &out_grad, + const Tensor &in_grad1, + const Tensor &in_grad2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_,cudaStream_t stream0,cudaStream_t stream1,int num,int channels,int height,int width) +{ + + // Get top diff, compute bottom diff + const Dtype* top_diff = out_grad.dptr_; + + Dtype* bottom0_diff = in_grad1.dptr_; + Dtype* bottom1_diff = in_grad2.dptr_; + + const Dtype* rbot1 = tmp1.dptr_; + const Dtype* rbot2 = tmp2.dptr_; + + const int paddedheight = height + 2*pad_size_; + const int paddedwidth = width + 2*pad_size_; + + const int bottomcount = channels * height * width; + int botThreadCount = bottomcount; + const int gridSize = (botThreadCount + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock; + + // CorrelationLayerBackward + + if(is_multiply == true) { + + // == Run kernel Backward 0 + dim3 totalBlocksBackward0(width, height, channels * num); //First dim is fastest + const int buffer_size_backw0 = ((int)ceil((float)(2 * kernel_radius_) / (float)stride1_) + 1) * top_channels_; + + // == Run kernel Backward 0 + for(int n = 0; n < num; n++) { + + CorrelateDataBackward0<<>>( + botThreadCount, + num, n, top_width_, top_height_, top_channels_, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, + stride1_, stride2_, + width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, + bottom0_diff, rbot2, top_diff + ); + + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + } + + // == Run kernel Backward 1 + for(int n = 0; n < num; n++) { + CorrelateDataBackward1<<>>( + botThreadCount, + num, n, top_width_, top_height_, top_channels_, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, + stride1_, stride2_, + width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, + rbot1, bottom1_diff, top_diff + ); + + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + } + + } else { + for(int n = 0; n < num; n++) { + //Bottom0: + CorrelateDataBackward0Subtract<<>>( + botThreadCount, + num, n, top_width_, top_height_, top_channels_, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, + stride1_, stride2_, + width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, + bottom0_diff, rbot1, rbot2, top_diff + ); + + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + } + + for(int n = 0; n < num; n++) { + //Bottom1: + CorrelateDataBackward1Subtract<<>>( + botThreadCount, + num, n, top_width_, top_height_, top_channels_, + max_displacement_, neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_, + stride1_, stride2_, + width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_, + rbot1, rbot2, bottom1_diff, top_diff + ); + CORRELATION_CUDA_CHECK(cudaPeekAtLastError()); + } + } +} + +} // namespace cuda + +template +inline void CorrelationForward( const Tensor &out, + const Tensor &data1, + const Tensor &data2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_ + ) { + cudaStream_t stream = Stream::GetStream(out.stream_); + cudaStream_t stream_tmp1 = Stream::GetStream(tmp1.stream_); + cudaStream_t stream_tmp2 = Stream::GetStream(tmp2.stream_); + cuda::Forward_gpu(out, data1, data2, tmp1,tmp2,top_channels_,top_height_,top_width_,pad_size_,is_multiply,max_displacement_,kernel_size_, + neighborhood_grid_radius_,neighborhood_grid_width_,kernel_radius_,stride1_,stride2_,stream,stream_tmp1,stream_tmp2); +} + +template +inline void CorrelationBackward(const Tensor &out_grad, + const Tensor &in_grad1, + const Tensor &in_grad2, + const Tensor &tmp1, + const Tensor &tmp2, + int top_channels_,int top_height_,int top_width_,int pad_size_,bool is_multiply, + int max_displacement_,int kernel_size_,int neighborhood_grid_radius_,int neighborhood_grid_width_, + int kernel_radius_,int stride1_,int stride2_,int num,int channels,int height,int width + ) { + cudaStream_t stream0 = Stream::GetStream(in_grad1.stream_); + cudaStream_t stream1 = Stream::GetStream(in_grad2.stream_); + cuda::Backward_gpu( out_grad,in_grad1,in_grad2,tmp1,tmp2,top_channels_,top_height_,top_width_,pad_size_,is_multiply, + max_displacement_,kernel_size_,neighborhood_grid_radius_,neighborhood_grid_width_,kernel_radius_,stride1_,stride2_,stream0,stream1,num,channels,height, width); +} + +} // namespace mshadow + +namespace mxnet { +namespace op { + +template<> +Operator* CreateOp(CorrelationParam param) { + return new CorrelationOp(param); +} +} // namespace op +} // namespace mxnet