From cfd2fe995702a225bbd7309937c4fcd9439b219a Mon Sep 17 00:00:00 2001
From: xu dong
Date: Wed, 19 Oct 2016 14:01:44 +0800
Subject: [PATCH] add 1D correlation

---
 src/operator/correlation1D-inl.h | 258 ++++++++++++++++++
 src/operator/correlation1D.cc    |  67 +++++
 src/operator/correlation1D.cu    | 456 +++++++++++++++++++++++++++++++
 3 files changed, 781 insertions(+)
 create mode 100644 src/operator/correlation1D-inl.h
 create mode 100644 src/operator/correlation1D.cc
 create mode 100644 src/operator/correlation1D.cu

diff --git a/src/operator/correlation1D-inl.h b/src/operator/correlation1D-inl.h
new file mode 100644
index 000000000000..3705441e1492
--- /dev/null
+++ b/src/operator/correlation1D-inl.h
@@ -0,0 +1,258 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file correlation1D-inl.h
+ * \brief correlation1D operator and symbol
+ * \author Xu Dong
+*/
+#ifndef MXNET_OPERATOR_CORRELATION1D_INL_H_
+#define MXNET_OPERATOR_CORRELATION1D_INL_H_
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include "./mshadow_op.h"
+#include "./operator_common.h"
+namespace mxnet {
+namespace op {
+// Declare enumeration of input order to make code more intuitive.
+// These enums are only visible within this header
+namespace Correlation1D {
+enum Correlation1DOpInputs{kData1, kData2};
+enum Correlation1DOpOutputs{kOut, kTemp1, kTemp2};
+}  // namespace Correlation1D
+/*! \brief Parameters of the Correlation1D operator. */
+struct Correlation1DParam : public dmlc::Parameter<Correlation1DParam> {
+  uint32_t max_displacement;
+  uint32_t kernel_size;
+  uint32_t pad_size;
+  uint32_t stride1;
+  uint32_t stride2;
+  // Signed: -1 selects the left-only search window, which an unsigned field cannot hold.
+  int single_side;
+  DMLC_DECLARE_PARAMETER(Correlation1DParam) {
+    DMLC_DECLARE_FIELD(kernel_size).set_default(1)
+    .describe("kernel size for Correlation1D must be an odd number");
+    DMLC_DECLARE_FIELD(max_displacement).set_default(1)
+    .describe("Max displacement of Correlation1D ");
+    DMLC_DECLARE_FIELD(stride1).set_default(1)
+    .describe("stride1 quantize data1 globally");
+    DMLC_DECLARE_FIELD(stride2).set_default(1)
+    .describe("stride2 quantize data2 within the neighborhood centered around data1");
+    DMLC_DECLARE_FIELD(pad_size).set_default(0)
+    .describe("pad for Correlation1D");
+    DMLC_DECLARE_FIELD(single_side).set_default(0)
+    .describe("0: both side, -1: to left, 1: to right");
+  }
+};
+/*!
+ * \brief Correlation1D operator: correlates data1 with horizontally shifted data2.
+ *
+ * Forward caches the output geometry in members; Backward reuses those values,
+ * so Forward must have run on this instance before Backward is called.
+ */
+template<typename xpu>
+class Correlation1DOp : public Operator {
+ public:
+  explicit Correlation1DOp(Correlation1DParam param) {
+    this->param_ = param;
+  }
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    CHECK_EQ(in_data.size(), 2);
+    CHECK_EQ(out_data.size(), 3);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> data1 = in_data[Correlation1D::kData1].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> data2 = in_data[Correlation1D::kData2].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> out = out_data[Correlation1D::kOut].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> tmp1 = out_data[Correlation1D::kTemp1].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> tmp2 = out_data[Correlation1D::kTemp2].get<xpu, 4, real_t>(s);
+    // Clear the output and both scratch buffers before they are filled.
+    tmp1 = 0.0f;
+    tmp2 = 0.0f;
+    out = 0.0f;
+    CHECK_EQ(data1.CheckContiguous(), true);
+    CHECK_EQ(data2.CheckContiguous(), true);
+    CHECK_EQ(out.CheckContiguous(), true);
+    CHECK_EQ(tmp1.CheckContiguous(), true);
+    CHECK_EQ(tmp2.CheckContiguous(), true);
+
+    paddedbottomheight = data1.shape_[2];
+    paddedbottomwidth = data1.shape_[3] + 2 * param_.pad_size;
+
+    kernel_radius_ = (param_.kernel_size - 1) / 2;
+    border_size_ = param_.max_displacement + kernel_radius_;
+    stride1 = param_.stride1;
+    stride2 = param_.stride2;
+    top_width_ = ceil(static_cast<float>(paddedbottomwidth - border_size_ * 2)\
+                 / static_cast<float>(stride1));
+    top_height_ = ceil(static_cast<float>(paddedbottomheight - kernel_radius_ * 2)\
+                 / static_cast<float>(stride1));
+    neighborhood_grid_radius_ = param_.max_displacement / stride2;
+    if (param_.single_side != 0)
+      neighborhood_grid_width_ = neighborhood_grid_radius_ + 1;
+    else
+      neighborhood_grid_width_ = neighborhood_grid_radius_ * 2 + 1;
+
+    top_channels_ = neighborhood_grid_width_;
+    num = data1.shape_[0];
+    channels = data1.shape_[1];
+    height = data1.shape_[2];
+    width = data1.shape_[3];
+    Correlation1DForward(out, data1, data2, tmp1, tmp2, top_channels_, top_height_, top_width_,
+                       param_.pad_size, param_.single_side,
+                       param_.max_displacement, param_.kernel_size,
+                       neighborhood_grid_radius_, neighborhood_grid_width_,
+                       kernel_radius_, param_.stride1, param_.stride2);
+  }
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 4> grad_data1 = in_grad[Correlation1D::kData1].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> grad_data2 = in_grad[Correlation1D::kData2].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> out_g = out_grad[Correlation1D::kOut].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> tmp1 = out_data[Correlation1D::kTemp1].get<xpu, 4, real_t>(s);
+    Tensor<xpu, 4> tmp2 = out_data[Correlation1D::kTemp2].get<xpu, 4, real_t>(s);
+    CHECK_EQ(grad_data1.CheckContiguous(), true);
+    CHECK_EQ(grad_data2.CheckContiguous(), true);
+    CHECK_EQ(out_g.CheckContiguous(), true);
+    CHECK_EQ(tmp1.CheckContiguous(), true);
+    CHECK_EQ(tmp2.CheckContiguous(), true);
+    // Geometry members below were cached by the preceding Forward call.
+    Correlation1DBackward(out_g, grad_data1, grad_data2, tmp1, tmp2, top_channels_,
+                       top_height_, top_width_, param_.pad_size, param_.single_side,
+                       param_.max_displacement, param_.kernel_size, neighborhood_grid_radius_,
+                       neighborhood_grid_width_, kernel_radius_, param_.stride1, param_.stride2,
+                       num, channels, height, width);
+  }
+
+ private:
+  Correlation1DParam param_;
+  int paddedbottomheight;
+  int paddedbottomwidth;
+  uint32_t kernel_radius_;
+  uint32_t border_size_;
+  uint32_t stride1;
+  uint32_t stride2;
+  uint32_t top_width_;
+  uint32_t top_height_;
+  uint32_t neighborhood_grid_radius_;
+  uint32_t neighborhood_grid_width_;
+  uint32_t top_channels_;
+  int num;
+  int channels;
+  int height;
+  int width;
+};  // class Correlation1DOp
+// Declare Factory function
+template<typename xpu>
+Operator* CreateOp(Correlation1DParam param);
+#if DMLC_USE_CXX11
+/*! \brief Operator property: shape inference and operator creation for Correlation1D. */
+class Correlation1DProp : public OperatorProperty {
+ public:
+  std::vector<std::string> ListArguments() const override {
+    return {"data1", "data2"};
+  }
+  std::vector<std::string> ListOutputs() const override {
+    return {"output", "tmp1", "tmp2"};
+  }
+  int NumOutputs() const override {
+    return 3;
+  }
+  int NumVisibleOutputs() const override {
+    return 1;
+  }
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    param_.Init(kwargs);
+  }
+  std::map<std::string, std::string> GetParams() const override {
+    return param_.__DICT__();
+  }
+  /*!
+   * \brief Validate the parameters and infer the output shapes.
+   *
+   * Outputs: the correlation map (N, D, top_h, top_w) plus two scratch buffers
+   * shaped (N, H, W + 2 * pad_size, C) for the two inputs.
+   */
+  bool InferShape(std::vector<TShape> *in_shape,
+                  std::vector<TShape> *out_shape,
+                  std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
+    CHECK_EQ(in_shape->size(), 2) << "Input:[data1, data2]";
+    TShape dshape1 = in_shape->at(Correlation1D::kData1);
+    TShape dshape2 = in_shape->at(Correlation1D::kData2);
+    CHECK_EQ(dshape1.ndim(), 4) << "data should be a 4D tensor";
+    CHECK_EQ(dshape2.ndim(), 4) << "data should be a 4D tensor";
+    // The kernels index data2 with data1's extents, so both inputs must match.
+    CHECK_EQ(dshape1, dshape2) << "data1 and data2 should have the same shape";
+    // An even (or zero) kernel size would wrap (kernel_size - 1) / 2 or overrun the border.
+    CHECK_EQ(param_.kernel_size % 2, 1U) << "kernel size for Correlation1D must be an odd number";
+    CHECK_GE(param_.stride1, 1U) << "stride1 must be at least 1";
+    CHECK_GE(param_.stride2, 1U) << "stride2 must be at least 1";
+    // Work in signed integers: with unsigned arithmetic a too-small input wraps to a
+    // huge extent and slips past the size checks below.
+    const int stride1 = static_cast<int>(param_.stride1);
+    const int stride2 = static_cast<int>(param_.stride2);
+    const int max_displacement = static_cast<int>(param_.max_displacement);
+    const int paddedbottomheight = static_cast<int>(dshape1[2]);
+    const int paddedbottomwidth = static_cast<int>(dshape1[3] + 2 * param_.pad_size);
+    const int kernel_radius_ = (static_cast<int>(param_.kernel_size) - 1) / 2;
+    const int border_size_ = max_displacement + kernel_radius_;
+    const int valid_width = paddedbottomwidth - border_size_ * 2;
+    const int valid_height = paddedbottomheight - kernel_radius_ * 2;
+    CHECK_GE(valid_width, 1) <<
+    "Correlation1D cannot be done with current settings.Neighborhood and kernel don't fit in blob";
+    CHECK_GE(valid_height, 1) <<
+    "Correlation1D cannot be done with current settings.Neighborhood and kernel don't fit in blob";
+    const int top_width_ = static_cast<int>(
+        ceil(static_cast<float>(valid_width) / static_cast<float>(stride1)));
+    const int top_height_ = static_cast<int>(
+        ceil(static_cast<float>(valid_height) / static_cast<float>(stride1)));
+    const int neighborhood_grid_radius_ = max_displacement / stride2;
+    const int neighborhood_grid_width_ = (param_.single_side != 0) ?
+        neighborhood_grid_radius_ + 1 : neighborhood_grid_radius_ * 2 + 1;
+    const int top_channels_ = neighborhood_grid_width_;
+    out_shape->clear();
+    out_shape->push_back(Shape4(dshape1[0], top_channels_, top_height_, top_width_));
+    out_shape->push_back(Shape4(dshape1[0], paddedbottomheight, paddedbottomwidth, dshape1[1]));
+    out_shape->push_back(Shape4(dshape1[0], paddedbottomheight, paddedbottomwidth, dshape1[1]));
+    return true;
+  }
+  OperatorProperty* Copy() const override {
+    Correlation1DProp* Correlation1D_sym = new Correlation1DProp();
+    Correlation1D_sym->param_ = this->param_;
+    return Correlation1D_sym;
+  }
+  std::string TypeString() const override {
+    return "Correlation1D";
+  }
+  // declare dependency and inplace optimization options
+  std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const override {
+    return {out_grad[Correlation1D::kOut],
+      out_data[Correlation1D::kTemp1], out_data[Correlation1D::kTemp2]};
+  }
+  Operator* CreateOperator(Context ctx) const override;
+
+ private:
+  Correlation1DParam param_;
+};  // class Correlation1DProp
+#endif
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_CORRELATION1D_INL_H_
diff --git a/src/operator/correlation1D.cc b/src/operator/correlation1D.cc
new file mode 100644
index 000000000000..50d1975e09be
--- /dev/null
+++ b/src/operator/correlation1D.cc
@@ -0,0 +1,67 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file correlation1D.cc
+ * \brief correlation1D op
+ * \author Xu Dong
+*/
+#include "./correlation1D-inl.h"
+#include "./mshadow_op.h"
+
+namespace mshadow {
+/*!
+ * \brief CPU forward pass of Correlation1D. Not implemented.
+ *
+ * Fails loudly: returning normally would leave the output at its cleared value.
+ */
+template<typename Dtype>
+inline void Correlation1DForward(const Tensor<cpu, 4, Dtype> &out,
+                                const Tensor<cpu, 4, Dtype> &data1,
+                                const Tensor<cpu, 4, Dtype> &data2,
+                                const Tensor<cpu, 4, Dtype> &tmp1,
+                                const Tensor<cpu, 4, Dtype> &tmp2,
+                                int top_channels_, int top_height_, int top_width_,
+                                int pad_size_, int single_side,
+                                int max_displacement_, int kernel_size_,
+                                int neighborhood_grid_radius_, int neighborhood_grid_width_,
+                                int kernel_radius_, int stride1_, int stride2_) {
+  LOG(FATAL) << "Correlation1D is not implemented for CPU";
+}
+/*!
+ * \brief CPU backward pass of Correlation1D. Not implemented.
+ *
+ * Fails loudly: returning normally would leave both input gradients unwritten.
+ */
+template<typename Dtype>
+inline void Correlation1DBackward(const Tensor<cpu, 4, Dtype> &out_grad,
+                                 const Tensor<cpu, 4, Dtype> &in_grad1,
+                                 const Tensor<cpu, 4, Dtype> &in_grad2,
+                                 const Tensor<cpu, 4, Dtype> &tmp1,
+                                 const Tensor<cpu, 4, Dtype> &tmp2,
+                                 int top_channels_, int top_height_,
+                                 int top_width_, int pad_size_,
+                                 int single_side, int max_displacement_,
+                                 int kernel_size_, int neighborhood_grid_radius_,
+                                 int neighborhood_grid_width_,
+                                 int kernel_radius_, int stride1_,
+                                 int stride2_, int num,
+                                 int channels, int height, int width) {
+  LOG(FATAL) << "Correlation1D is not implemented for CPU";
+}
+}  // namespace mshadow
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>(Correlation1DParam param) {
+  return new Correlation1DOp<cpu>(param);
+}
+Operator* Correlation1DProp::CreateOperator(Context ctx) const {
+  DO_BIND_DISPATCH(CreateOp, param_);
+}
+DMLC_REGISTER_PARAMETER(Correlation1DParam);
+MXNET_REGISTER_OP_PROPERTY(Correlation1D, Correlation1DProp)
+.describe("Apply correlation1D to inputs")
+.add_argument("data1", "Symbol", "Input data1 to the correlation1D.")
+.add_argument("data2", "Symbol", "Input data2 to the correlation1D.")
+.add_arguments(Correlation1DParam::__FIELDS__());
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/correlation1D.cu b/src/operator/correlation1D.cu
new file mode 100644
index 000000000000..a39a25b706aa
--- /dev/null
+++ b/src/operator/correlation1D.cu
@@ -0,0 +1,456 @@
+/*!
+ * Copyright [2016]
+ * \file correlation1D.cu
+ * \brief correlation1D operator
+ * \author Xu Dong
+*/
+#include "./correlation1D-inl.h"
+#include <mshadow/tensor.h>
+#include <mshadow/cuda/reduce.cuh>
+#include <algorithm>
+#include <vector>
+
+#define ROUND_OFF 50000
+#define WARPS_PER_BLOCK 1
+#define THREADS_PER_WARP 32
+#define correlation1D_CUDA_CHECK(condition) \
+  /* Code block avoids redefinition of cudaError_t error */ \
+  do { \
+    cudaError_t error = condition; \
+    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
+  } while (0)
+#define CUDA_KERNEL_LOOP(i, n) \
+for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+      i < (n); \
+      i += blockDim.x * gridDim.x)
+
+namespace mshadow {
+namespace cuda {
+// == correlation1D Kernel
+/*!
+ * \brief Forward kernel: one block per output pixel (x, y) and batch item.
+ *
+ * The block stages the data1 patch in dynamic shared memory, then for every
+ * displacement channel reduces the product with the shifted data2 patch.
+ * bottom0/bottom1 are the rearranged (N, H, W, C) buffers.
+ */
+template <typename Dtype>
+__global__ void Correlate1DData(const int nthreads, int num, int topwidth, int topheight,
+  int topchannels, int topcount,
+  int max_displacement, int x_shift, int neighborhood_grid_width,
+  int kernel_radius, int kernel_size, int stride1, int stride2,
+  int bottomwidth, int bottomheight, int bottomchannels,
+  const Dtype *bottom0, const Dtype *bottom1, Dtype *top) {
+  extern __shared__ char patch_data_char[];
+
+  Dtype *patch_data = reinterpret_cast<Dtype *>(patch_data_char);
+
+  // First (upper left) position of kernel upper-left corner
+  // in current center position of neighborhood in image 1
+  int x1 = blockIdx.x*stride1 + max_displacement;
+  int y1 = blockIdx.y*stride1;
+  int item = blockIdx.z;
+  int ch_off = threadIdx.x;
+
+  // Load 3D patch into shared memory
+  for (int j = 0; j < kernel_size; j++) {  // HEIGHT
+    for (int i = 0; i < kernel_size; i++) {  // WIDTH
+      int ji_off = ((j * kernel_size) + i) * bottomchannels;
+      for (int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) {
+        int idx1 = ((item * bottomheight + y1+j) * bottomwidth + x1+i) * bottomchannels + ch;
+        int idxPatchData = ji_off + ch;
+        patch_data[idxPatchData] = bottom0[idx1];
+      }
+    }
+  }
+
+  __syncthreads();
+
+  __shared__ Dtype sum[WARPS_PER_BLOCK*THREADS_PER_WARP];
+
+  // Compute
+  for (int top_channel = 0; top_channel < topchannels; top_channel++) {
+    sum[ch_off] = 0;
+
+    int s2o = (top_channel % neighborhood_grid_width + x_shift) * stride2;
+
+    for (int j = 0; j < kernel_size; j++) {  // HEIGHT
+      for (int i = 0; i < kernel_size; i++) {  // WIDTH
+        int ji_off = ((j * kernel_size) + i) * bottomchannels;
+        for (int ch = ch_off; ch < bottomchannels; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) {
+          int x2 = x1 + s2o;
+
+          int idxPatchData = ji_off + ch;
+          int idx2 = ((item * bottomheight + y1+j) * bottomwidth + x2+i) * bottomchannels + ch;
+
+          sum[ch_off] += patch_data[idxPatchData] * bottom1[idx2];
+        }
+      }
+    }
+
+    __syncthreads();
+
+    if (ch_off == 0) {
+      Dtype total_sum = 0;
+      for (int idx = 0; idx < WARPS_PER_BLOCK*THREADS_PER_WARP; idx++) {
+        total_sum += sum[idx];
+      }
+      const int sumelems = kernel_size*kernel_size*bottomchannels;
+      const int index = ((top_channel*topheight + blockIdx.y)*topwidth)+blockIdx.x;
+      top[index + item*topcount] = total_sum / static_cast<float>(sumelems);
+    }
+    // Thread 0 must finish reading sum[] before any thread resets its slot
+    // for the next displacement channel.
+    __syncthreads();
+  }
+}
+
+// == correlation1D Backward Pass Kernel (For data1)
+/*!
+ * \brief Gradient w.r.t. data1 for batch item `item`; one loop index per (h, w, c) element.
+ *
+ * bottom1 is the rearranged, padded data2 buffer; bottom0diff is written in (C, H, W) order.
+ */
+template <typename Dtype>
+__global__ void Correlate1DDataBackward0(const int nthreads, int num, int item,
+  int topwidth, int topheight, int topchannels,
+  int max_displacement, int x_shift,
+  int neighborhood_grid_width, int kernel_radius, int stride1, int stride2,
+  int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight,
+  int bottomchannels, int bottomcount, int pad_size,
+  Dtype *bottom0diff, const Dtype *bottom1, const Dtype *topdiff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    int n = index % bottomchannels;  // channels
+    int l = (index / bottomchannels) % bottomwidth + pad_size;  // w-pos
+    int m = (index / bottomchannels / bottomwidth) % bottomheight;  // h-pos
+
+    // Get X,Y ranges and clamp
+    // round_off is a trick to enable integer division with ceil, even for negative numbers
+    // We use a large offset, for the inner part not to become negative.
+    const int round_off = ROUND_OFF;
+    const int round_off_s1 = stride1 * round_off;
+
+    // We add round_off before the int division and subtract round_off after it,
+    // to ensure the formula matches ceil behavior:
+    // ceil (l - 2*kernel_radius - max_displacement) / stride1
+    int xmin = (l - 2*kernel_radius - max_displacement + round_off_s1 - 1)
+               / stride1 + 1 - round_off;
+    // ceil (m - 2*kernel_radius) / stride1
+    int ymin = (m - 2*kernel_radius - 0 + round_off_s1 - 1) / stride1 + 1 - round_off;
+
+    // Same here:
+    // floor (l - max_displacement) / stride1
+    int xmax = (l - max_displacement + round_off_s1) / stride1 - round_off;
+    // floor (m) / stride1
+    int ymax = (m - 0 + round_off_s1) / stride1 - round_off;
+
+    Dtype sum = 0;
+    if (xmax >= 0 && ymax >= 0 && (xmin <= topwidth-1) && (ymin <= topheight-1)) {
+      xmin = max(0, xmin);
+      xmax = min(topwidth-1, xmax);
+
+      ymin = max(0, ymin);
+      ymax = min(topheight-1, ymax);
+
+      for (int o = x_shift; o < x_shift + neighborhood_grid_width; o++) {
+        // Get bottom1 data:
+        int s2o = stride2 * o;
+        int idxbot1 = ((item * pbottomheight + m) * pbottomwidth + (l+s2o)) * bottomchannels + n;
+        Dtype bot1tmp = bottom1[idxbot1];  // bottom1[l+s2o,m,n]
+
+        // Index offset for topdiff in following loops:
+        int op = (o-x_shift);  // index [o,p]
+        int idxopoffset = (item * topchannels + op);
+
+        for (int y = ymin; y <= ymax; y++) {
+          for (int x = xmin; x <= xmax; x++) {
+            int idxtopdiff = (idxopoffset * topheight + y) * topwidth + x;  // topdiff[x,y,o,p]
+            sum += topdiff[idxtopdiff] * bot1tmp;
+          }
+        }
+      }
+    }
+    const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels;
+    const int bot0index = ((n * bottomheight) + m) * bottomwidth + (l-pad_size);
+    bottom0diff[bot0index + item*bottomcount] = sum / static_cast<float>(sumelems);
+  }
+}
+
+// == correlation1D Backward Pass Kernel (For data2)
+/*!
+ * \brief Gradient w.r.t. data2 for batch item `item`; one loop index per (h, w, c) element.
+ *
+ * bottom0 is the rearranged, padded data1 buffer; bottom1diff is written in (C, H, W) order.
+ */
+template <typename Dtype>
+__global__ void Correlate1DDataBackward1(const int nthreads,
+  int num, int item, int topwidth, int topheight, int topchannels,
+  int max_displacement, int x_shift,
+  int neighborhood_grid_width, int kernel_radius, int stride1, int stride2,
+  int bottomwidth, int bottomheight, int pbottomwidth, int pbottomheight,
+  int bottomchannels, int bottomcount, int pad_size,
+  const Dtype *bottom0, Dtype *bottom1diff, const Dtype *topdiff) {
+  CUDA_KERNEL_LOOP(index, nthreads) {
+    int n = index % bottomchannels;  // channels
+    int l = (index / bottomchannels) % bottomwidth + pad_size;  // w-pos
+    int m = (index / bottomchannels / bottomwidth) % bottomheight;  // h-pos
+
+    // round_off is a trick to enable integer division with ceil, even for negative numbers
+    // We use a large offset, for the inner part not to become negative.
+    const int round_off = ROUND_OFF;
+    const int round_off_s1 = stride1 * round_off;
+
+    Dtype sum = 0;
+    for (int o = x_shift; o < x_shift + neighborhood_grid_width; o++) {
+      int s2o = stride2 * o;
+
+      // Get X,Y ranges and clamp
+      // We add round_off before the int division and subtract round_off after it,
+      // to ensure the formula matches ceil behavior:
+      // ceil (l - 2*kernel_radius - max_displacement - s2o) / stride1
+      int xmin = (l - 2*kernel_radius - max_displacement - s2o + round_off_s1 - 1)
+                 / stride1 + 1 - round_off;
+      // ceil (m - 2*kernel_radius) / stride1
+      int ymin = (m - 2*kernel_radius - 0 - 0 + round_off_s1 - 1) / stride1 + 1 - round_off;
+
+      // Same here:
+      // floor (l - max_displacement - s2o) / stride1
+      int xmax = (l - max_displacement - s2o + round_off_s1) / stride1 - round_off;
+      // floor (m) / stride1
+      int ymax = (m - 0 - 0 + round_off_s1) / stride1 - round_off;
+
+      if (xmax >= 0 && ymax >= 0 && (xmin <= topwidth-1) && (ymin <= topheight-1)) {
+        xmin = max(0, xmin);
+        xmax = min(topwidth-1, xmax);
+
+        ymin = max(0, ymin);
+        ymax = min(topheight-1, ymax);
+
+        // Get bottom0 data:
+        int idxbot0 = ((item * pbottomheight + m) * pbottomwidth + (l-s2o)) * bottomchannels + n;
+        Dtype bot0tmp = bottom0[idxbot0];  // bottom0[l-s2o,m,n]
+
+        // Index offset for topdiff in following loops:
+        int op = (o-x_shift);  // index [o,p]
+        int idxOpOffset = (item * topchannels + op);
+
+        for (int y = ymin; y <= ymax; y++) {
+          for (int x = xmin; x <= xmax; x++) {
+            int idxtopdiff = (idxOpOffset * topheight + y) * topwidth + x;  // topdiff[x,y,o,p]
+            sum += topdiff[idxtopdiff] * bot0tmp;
+          }
+        }
+      }
+    }
+    const int sumelems = (kernel_radius*2+1)*(kernel_radius*2+1)*bottomchannels;
+    const int bot1index = ((n * bottomheight) + m) * bottomwidth + (l-pad_size);
+    bottom1diff[bot1index + item*bottomcount] = sum / static_cast<float>(sumelems);
+  }
+}
+
+// == Forward
+// == Dimension rearrangement Kernel
+/*!
+ * \brief Copy `in` (N, C, H, W) into `out` (N, H, W + 2 * padding, C), offset by `padding`
+ *        columns. Padding columns of `out` are not written.
+ */
+template <typename Dtype>
+__global__ void blob_rearrange_kernel2(const Dtype* in, Dtype* out, int num,
+  int channels, int width, int height, int widthheight, int padding, int pwidthheight) {
+  // change shape from [batchsize,channel,y,x] to [batchsize,y,x,channel]
+  int xy = blockIdx.x*blockDim.x + threadIdx.x;
+  if (xy >= widthheight)
+    return;
+
+  int ch = blockIdx.y;
+  int n = blockIdx.z;
+
+  // Keep the element in Dtype; a float temporary would truncate wider types.
+  // No barrier is needed: the kernel uses no shared memory, and a __syncthreads()
+  // after the early return above could be skipped by part of the block.
+  const Dtype value = in[(n*channels+ch)*widthheight+xy];
+
+  int xpad = (xy % width + padding);
+  int ypad = (xy / width + 0);
+  int xypad = ypad * (width+2*padding) + xpad;
+
+  out[(n*pwidthheight+xypad)*channels + ch] = value;
+}
+/*!
+ * \brief Launch the forward pass: rearrange both inputs into tmp1/tmp2, then correlate.
+ */
+template <typename Dtype>
+void Forward_gpu(
+      const Tensor<gpu, 4, Dtype> &out,
+      const Tensor<gpu, 4, Dtype> &data1,
+      const Tensor<gpu, 4, Dtype> &data2,
+      const Tensor<gpu, 4, Dtype> &tmp1,
+      const Tensor<gpu, 4, Dtype> &tmp2,
+      int top_channels_, int top_height_, int top_width_, int pad_size_,
+      int single_side, int max_displacement_, int kernel_size_,
+      int neighborhood_grid_radius_, int neighborhood_grid_width_,
+      int kernel_radius_, int stride1_, int stride2_, cudaStream_t stream,
+      cudaStream_t stream_tmp1, cudaStream_t stream_tmp2) {
+  const Dtype *bottom_data1 = data1.dptr_;
+  const Dtype *bottom_data2 = data2.dptr_;
+  Dtype *rbot1 = tmp1.dptr_;
+  Dtype *rbot2 = tmp2.dptr_;
+  Dtype *top = out.dptr_;
+  const int bnum = data1.size(0);
+  const int bchannels = data1.size(1);
+  const int bheight = data1.size(2);
+  const int bwidth = data1.size(3);
+  const int bwidthheight = bwidth * bheight;
+  const int topcount = top_width_ * top_height_ * top_channels_;
+  dim3 threadsPerBlock(THREADS_PER_WARP * WARPS_PER_BLOCK);
+  int threads_per_block = 16;
+  dim3 totalBlocksRearr((bwidthheight - 1) / threads_per_block + 1, bchannels, bnum);
+  const int pwidthheight = (bwidth + 2 * pad_size_) * (bheight);
+  blob_rearrange_kernel2<Dtype><<<totalBlocksRearr, threads_per_block, 0, stream_tmp1>>>
+  (bottom_data1, rbot1, bnum, bchannels, bwidth, bheight, bwidthheight, pad_size_, pwidthheight);
+  blob_rearrange_kernel2<Dtype><<<totalBlocksRearr, threads_per_block, 0, stream_tmp2>>>
+  (bottom_data2, rbot2, bnum, bchannels, bwidth, bheight, bwidthheight, pad_size_, pwidthheight);
+  const int num = bnum;
+  const int channels = bchannels;
+  const int height = bheight;
+  const int width = bwidth + 2 * pad_size_;
+  // Dynamic shared memory holds one kernel_size x kernel_size x channels patch of data1.
+  const int shared_memory_per_block = (kernel_size_ * kernel_size_) * bchannels;
+
+  int x_shift = - neighborhood_grid_radius_;
+  if (single_side == -1) {  // to the left
+    // NOTE(review): this starts the window at -(radius + 1), so the leftmost shift
+    // exceeds max_displacement and excludes zero; confirm this is intended.
+    x_shift = -neighborhood_grid_width_;
+  } else if (single_side == 1) {  // to the right
+    x_shift = 0;
+  }
+  // correlation1DLayer
+  int topThreadCount = topcount;
+  dim3 totalBlocksCorr(top_width_, top_height_, num);
+
+  Correlate1DData<Dtype><<<totalBlocksCorr, threadsPerBlock,
+    shared_memory_per_block * sizeof(Dtype), stream>>>(
+      topThreadCount,
+      num, top_width_, top_height_, top_channels_, topcount,
+      max_displacement_, x_shift,
+      neighborhood_grid_width_, kernel_radius_, kernel_size_,
+      stride1_, stride2_,
+      width, height, channels,
+      rbot1, rbot2, top);
+  correlation1D_CUDA_CHECK(cudaPeekAtLastError());
+}
+/*!
+ * \brief Launch the backward pass: one kernel launch per batch item and per input.
+ *
+ * tmp1/tmp2 must still hold the rearranged inputs written by Forward_gpu.
+ */
+template <typename Dtype>
+void Backward_gpu(
+      const Tensor<gpu, 4, Dtype> &out_grad,
+      const Tensor<gpu, 4, Dtype> &in_grad1,
+      const Tensor<gpu, 4, Dtype> &in_grad2,
+      const Tensor<gpu, 4, Dtype> &tmp1,
+      const Tensor<gpu, 4, Dtype> &tmp2,
+      int top_channels_, int top_height_,
+      int top_width_, int pad_size_, int single_side,
+      int max_displacement_, int kernel_size_,
+      int neighborhood_grid_radius_, int neighborhood_grid_width_,
+      int kernel_radius_, int stride1_, int stride2_,
+      cudaStream_t stream0, cudaStream_t stream1,
+      int num, int channels, int height, int width) {
+  // Get top diff, compute bottom diff
+  const Dtype* top_diff = out_grad.dptr_;
+  Dtype* bottom0_diff = in_grad1.dptr_;
+  Dtype* bottom1_diff = in_grad2.dptr_;
+  const Dtype* rbot1 = tmp1.dptr_;
+  const Dtype* rbot2 = tmp2.dptr_;
+  const int paddedheight = height;
+  const int paddedwidth = width + 2 * pad_size_;
+  const int bottomcount = channels * height * width;
+  int botThreadCount = bottomcount;
+  const int gridSize = (botThreadCount + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
+  // correlation1DLayerBackward
+
+  int x_shift = - neighborhood_grid_radius_;
+  if (single_side == -1) {  // to the left
+    x_shift = -neighborhood_grid_width_;
+  } else if (single_side == 1) {  // to the right
+    x_shift = 0;
+  }
+
+  // == Run kernel Backward 0
+  for (int n = 0; n < num; n++) {
+    Correlate1DDataBackward0<Dtype><<<gridSize, kMaxThreadsPerBlock, 0, stream0>>>(
+        botThreadCount,
+        num, n, top_width_, top_height_, top_channels_,
+        max_displacement_, x_shift, neighborhood_grid_width_, kernel_radius_,
+        stride1_, stride2_,
+        width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_,
+        bottom0_diff, rbot2, top_diff);
+    correlation1D_CUDA_CHECK(cudaPeekAtLastError());
+  }
+  // == Run kernel Backward 1
+  for (int n = 0; n < num; n++) {
+    Correlate1DDataBackward1<Dtype><<<gridSize, kMaxThreadsPerBlock, 0, stream1>>>(
+        botThreadCount,
+        num, n, top_width_, top_height_, top_channels_,
+        max_displacement_, x_shift, neighborhood_grid_width_, kernel_radius_,
+        stride1_, stride2_,
+        width, height, paddedwidth, paddedheight, channels, bottomcount, pad_size_,
+        rbot1, bottom1_diff, top_diff);
+    correlation1D_CUDA_CHECK(cudaPeekAtLastError());
+  }
+}
+}  // namespace cuda
+/*! \brief GPU overload: resolves the CUDA streams and delegates to cuda::Forward_gpu. */
+template<typename Dtype>
+inline void Correlation1DForward(const Tensor<gpu, 4, Dtype> &out,
+                                const Tensor<gpu, 4, Dtype> &data1,
+                                const Tensor<gpu, 4, Dtype> &data2,
+                                const Tensor<gpu, 4, Dtype> &tmp1,
+                                const Tensor<gpu, 4, Dtype> &tmp2,
+                                int top_channels_, int top_height_,
+                                int top_width_, int pad_size_, int single_side,
+                                int max_displacement_, int kernel_size_,
+                                int neighborhood_grid_radius_, int neighborhood_grid_width_,
+                                int kernel_radius_, int stride1_, int stride2_) {
+  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
+  cudaStream_t stream_tmp1 = Stream<gpu>::GetStream(tmp1.stream_);
+  cudaStream_t stream_tmp2 = Stream<gpu>::GetStream(tmp2.stream_);
+  cuda::Forward_gpu(out, data1, data2, tmp1, tmp2, top_channels_, top_height_,
+                    top_width_, pad_size_, single_side, max_displacement_, kernel_size_,
+                    neighborhood_grid_radius_, neighborhood_grid_width_, kernel_radius_,
+                    stride1_, stride2_, stream, stream_tmp1, stream_tmp2);
+}
+
+/*! \brief GPU overload: resolves the CUDA streams and delegates to cuda::Backward_gpu. */
+template<typename Dtype>
+inline void Correlation1DBackward(const Tensor<gpu, 4, Dtype> &out_grad,
+                                 const Tensor<gpu, 4, Dtype> &in_grad1,
+                                 const Tensor<gpu, 4, Dtype> &in_grad2,
+                                 const Tensor<gpu, 4, Dtype> &tmp1,
+                                 const Tensor<gpu, 4, Dtype> &tmp2,
+                                 int top_channels_, int top_height_,
+                                 int top_width_, int pad_size_, int single_side,
+                                 int max_displacement_, int kernel_size_,
+                                 int neighborhood_grid_radius_, int neighborhood_grid_width_,
+                                 int kernel_radius_, int stride1_,
+                                 int stride2_, int num, int channels, int height, int width) {
+  cudaStream_t stream0 = Stream<gpu>::GetStream(in_grad1.stream_);
+  cudaStream_t stream1 = Stream<gpu>::GetStream(in_grad2.stream_);
+  cuda::Backward_gpu(out_grad, in_grad1, in_grad2, tmp1, tmp2, top_channels_,
+                     top_height_, top_width_, pad_size_, single_side,
+                     max_displacement_, kernel_size_, neighborhood_grid_radius_,
+                     neighborhood_grid_width_, kernel_radius_, stride1_, stride2_,
+                     stream0, stream1, num, channels, height, width);
+}
+}  // namespace mshadow
+namespace mxnet {
+namespace op {
+template<>
+Operator* CreateOp<gpu>(Correlation1DParam param) {
+  return new Correlation1DOp<gpu>(param);
+}
+}  // namespace op
+}  // namespace mxnet