From 4da6c55610eb0b59f92a1aff47ec7af8e526b060 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 28 Aug 2018 22:59:33 +0000 Subject: [PATCH] Add support for more req patterns for bilinear sampler backward --- src/operator/bilinear_sampler-inl.h | 11 ++- src/operator/bilinear_sampler.cc | 90 +++++++++++++------------ src/operator/bilinear_sampler.cu | 100 ++++++++++++++++------------ 3 files changed, 108 insertions(+), 93 deletions(-) diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h index 657aebafdb74..a872146507f2 100644 --- a/src/operator/bilinear_sampler-inl.h +++ b/src/operator/bilinear_sampler-inl.h @@ -92,19 +92,16 @@ class BilinearSamplerOp : public Operator { Tensor gdata = in_grad[bs::kData].get(s); Tensor ggrid = in_grad[bs::kGrid].get(s); Tensor grad = out_grad[bs::kOut].get(s); - if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) { + if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { + return; + } else { if (req[bs::kData] == kWriteTo) { gdata = scalar(0.0f); } if (req[bs::kGrid] == kWriteTo) { ggrid = scalar(0.0f); } - BilinearSamplerBackward(gdata, ggrid, grad, data, grid); - } else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { - return; - } else { - LOG(FATAL) << "Have not implemented the data req combinations! gdata_req=" - << req[bs::kData] << " ggrid_req=" << req[bs::kGrid]; + BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]); } } diff --git a/src/operator/bilinear_sampler.cc b/src/operator/bilinear_sampler.cc index 3365d98bb4db..bafa8bd97c43 100644 --- a/src/operator/bilinear_sampler.cc +++ b/src/operator/bilinear_sampler.cc @@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor &output, template inline void BilinearSamplerBackward(const Tensor &gdata, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid) { + const Tensor &ggrid, + const Tensor &output_grad, + const Tensor &input_data, + const Tensor &grid, + const mxnet::OpReqType data_req, + const mxnet::OpReqType grid_req) { DType *g_input = gdata.dptr_; DType *grad_grid = ggrid.dptr_; const DType *grid_src = grid.dptr_; @@ -102,46 +104,50 @@ inline void BilinearSamplerBackward(const Tensor &gdata, int top_left_x = static_cast(floor(x_real)); DType top_left_y_w = 1.0 - (y_real - top_left_y); DType top_left_x_w = 1.0 - (x_real - top_left_x); - for (index_t c = 0; c < static_cast(o_c); ++c) { - index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w - + top_left_x; - // calc 4 vertex value in input data - DType top_left_v = 0; - DType top_right_v = 0; - DType bottom_left_v = 0; - DType bottom_right_v = 0; - // calc input grad - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; - top_left_v = *(data + data_index); + if (data_req != mxnet::kNullOp) { + for (index_t c = 0; c < static_cast(o_c); ++c) { + index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + + top_left_x; + // calc 4 vertex value in input data + DType top_left_v = 0; + DType top_right_v = 0; + DType bottom_left_v = 0; + DType bottom_right_v = 0; + // calc input grad + if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; + top_left_v = *(data + data_index); + } + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w + * (1.0 - top_left_x_w); + top_right_v = *(data + data_index + 1); + } + if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) + * top_left_x_w; + bottom_left_v = *(data + data_index + i_w); + } + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) + * (1.0 - top_left_x_w); + bottom_right_v = *(data + data_index + i_w + 1); + } + // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src + top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) + * top_left_x_w); + top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) + * top_left_y_w); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w - * (1.0 - top_left_x_w); - top_right_v = *(data + data_index + 1); - } - if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) - * top_left_x_w; - bottom_left_v = *(data + data_index + i_w); - } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) - * (1.0 - top_left_x_w); - bottom_right_v = *(data + data_index + i_w + 1); - } - // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src - top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_x_w); - top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_y_w); } - // calc grad of grid - *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; - *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + if (grid_req != mxnet::kNullOp) { + // calc grad of grid + *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; + *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + } } } } diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index 0ab628da700b..cd67fe561496 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h, } } -template +template __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, const int i_w, const DType* grad, const DType* data, const int o_n, @@ -104,45 +104,49 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, int top_left_x = static_cast(floor(x_real)); DType top_left_y_w = 1.0 - (y_real - top_left_y); DType top_left_x_w = 1.0 - (x_real - top_left_x); - for (index_t c = 0; c < o_c; ++c) { - index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; - // calc 4 vertex value in input data - DType top_left_v = 0; - DType top_right_v = 0; - DType bottom_left_v = 0; - DType bottom_right_v = 0; - // calc input grad - if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w); - top_left_v = *(data + data_index); - } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w - * (1.0 - top_left_x_w)); - top_right_v = *(data + data_index + 1); - } - if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w) - * top_left_x_w); - bottom_left_v = *(data + data_index + i_w); + if (Req1 != mxnet::kNullOp) { + for (index_t c = 0; c < o_c; ++c) { + index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; + // calc 4 vertex value in input data + DType top_left_v = 0; + DType top_right_v = 0; + DType bottom_left_v = 0; + DType bottom_right_v = 0; + // calc input grad + if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w); + top_left_v = *(data + data_index); + } + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { + atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w + * (1.0 - top_left_x_w)); + top_right_v = *(data + data_index + 1); + } + if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w) + * top_left_x_w); + bottom_left_v = *(data + data_index + i_w); + } + if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { + atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w) + * (1.0 - top_left_x_w)); + bottom_right_v = *(data + data_index + i_w + 1); + } + // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src + top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) + * top_left_x_w); + top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + + (top_left_v - top_right_v - bottom_left_v + bottom_right_v) + * top_left_y_w); } - if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w) - * (1.0 - top_left_x_w)); - bottom_right_v = *(data + data_index + i_w + 1); - } - // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src - top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_x_w); - top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v + - (top_left_v - top_right_v - bottom_left_v + bottom_right_v) - * top_left_y_w); } - // calc grad of grid - *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; - *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + if (Req2 != mxnet::kNullOp) { + // calc grad of grid + *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; + *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + } } } } // namespace cuda @@ -174,10 +178,13 @@ inline void BilinearSamplerForward(const Tensor &output, template inline void BilinearSamplerBackward(const Tensor &input_grad, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid) { + const Tensor &ggrid, + const Tensor &output_grad, + const Tensor &input_data, + const Tensor &grid, + const mxnet::OpReqType data_req, + const mxnet::OpReqType grid_req) { + using namespace mxnet; DType *g_input = input_grad.dptr_; DType *grad_grid = ggrid.dptr_; const DType *grid_src = grid.dptr_; @@ -196,8 +203,13 @@ inline void BilinearSamplerBackward(const Tensor &input_grad, dim3 threads_per_block(kMaxThreadsPerBlock); CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward"); cudaStream_t stream = Stream::GetStream(input_grad.stream_); - cuda::BilinearSamplerBackwardKernel << > >( - i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); + MXNET_ASSIGN_REQ_SWITCH(data_req, Req1, { + MXNET_ASSIGN_REQ_SWITCH(grid_req, Req2, { + cuda::BilinearSamplerBackwardKernel + <<>>( + i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); + }); + }); // post kernel check cudaError err = cudaPeekAtLastError(); CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);