diff --git a/src/operator/bilinear_sampler-inl.h b/src/operator/bilinear_sampler-inl.h index e0b4db7b367c..499d23396207 100644 --- a/src/operator/bilinear_sampler-inl.h +++ b/src/operator/bilinear_sampler-inl.h @@ -95,19 +95,16 @@ class BilinearSamplerOp : public Operator { Tensor gdata = in_grad[bs::kData].get(s); Tensor ggrid = in_grad[bs::kGrid].get(s); Tensor grad = out_grad[bs::kOut].get(s); - if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) { + if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { + return; + } else { if (req[bs::kData] == kWriteTo) { gdata = scalar(0.0f); } if (req[bs::kGrid] == kWriteTo) { ggrid = scalar(0.0f); } - BilinearSamplerBackward(gdata, ggrid, grad, data, grid); - } else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) { - return; - } else { - LOG(FATAL) << "Have not implemented the data req combinations! gdata_req=" - << req[bs::kData] << " ggrid_req=" << req[bs::kGrid]; + BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]); } } diff --git a/src/operator/bilinear_sampler.cc b/src/operator/bilinear_sampler.cc index 3365d98bb4db..a3b7d5764245 100644 --- a/src/operator/bilinear_sampler.cc +++ b/src/operator/bilinear_sampler.cc @@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor &output, template inline void BilinearSamplerBackward(const Tensor &gdata, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid) { + const Tensor &ggrid, + const Tensor &output_grad, + const Tensor &input_data, + const Tensor &grid, + const mxnet::OpReqType data_req, + const mxnet::OpReqType grid_req) { DType *g_input = gdata.dptr_; DType *grad_grid = ggrid.dptr_; const DType *grid_src = grid.dptr_; @@ -104,8 +106,7 @@ inline void BilinearSamplerBackward(const Tensor &gdata, DType top_left_x_w = 1.0 - (x_real - top_left_x); for (index_t c = 0; c < static_cast(o_c); ++c) { index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w; - int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w - + top_left_x; + int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x; // calc 4 vertex value in input data DType top_left_v = 0; DType top_right_v = 0; @@ -113,22 +114,30 @@ inline void BilinearSamplerBackward(const Tensor &gdata, DType bottom_right_v = 0; // calc input grad if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; + if (data_req != mxnet::kNullOp) { + *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w; + } top_left_v = *(data + data_index); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w - * (1.0 - top_left_x_w); + if (data_req != mxnet::kNullOp) { + *(g_input + data_index + 1) += + *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w); + } top_right_v = *(data + data_index + 1); } if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w) - * top_left_x_w; + if (data_req != mxnet::kNullOp) { + *(g_input + data_index+ i_w) += + *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w; + } bottom_left_v = *(data + data_index + i_w); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w) - * (1.0 - top_left_x_w); + if (data_req != mxnet::kNullOp) { + *(g_input + data_index+ i_w + 1) += + *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w); + } bottom_right_v = *(data + data_index + i_w + 1); } // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src @@ -139,9 +148,11 @@ inline void BilinearSamplerBackward(const Tensor &gdata, (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_y_w); } - // calc grad of grid - *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; - *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + if (grid_req != mxnet::kNullOp) { + // calc grad of grid + *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; + *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + } } } } diff --git a/src/operator/bilinear_sampler.cu b/src/operator/bilinear_sampler.cu index e1f205258a24..2e6be3e1ef3e 100644 --- a/src/operator/bilinear_sampler.cu +++ b/src/operator/bilinear_sampler.cu @@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h, } } -template +template __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, const int i_w, const DType* grad, const DType* data, const int o_n, @@ -114,22 +114,30 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, DType bottom_right_v = 0; // calc input grad if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w); + if (Req1 != mxnet::kNullOp) { + atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w); + } top_left_v = *(data + data_index); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) { - atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w - * (1.0 - top_left_x_w)); + if (Req1 != mxnet::kNullOp) { + atomicAdd(&g_input[data_index + 1], + *(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w)); + } top_right_v = *(data + data_index + 1); } if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w) - * top_left_x_w); + if (Req1 != mxnet::kNullOp) { + atomicAdd(&g_input[data_index+ i_w], + *(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w); + } bottom_left_v = *(data + data_index + i_w); } if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) { - atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w) - * (1.0 - top_left_x_w)); + if (Req1 != mxnet::kNullOp) { + atomicAdd(&g_input[data_index+ i_w + 1], + *(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w)); + } bottom_right_v = *(data + data_index + i_w + 1); } // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src @@ -140,9 +148,11 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h, (top_left_v - top_right_v - bottom_left_v + bottom_right_v) * top_left_y_w); } - // calc grad of grid - *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; - *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + if (Req2 != mxnet::kNullOp) { + // calc grad of grid + *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2; + *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2; + } } } } // namespace cuda @@ -174,10 +184,13 @@ inline void BilinearSamplerForward(const Tensor &output, template inline void BilinearSamplerBackward(const Tensor &input_grad, - const Tensor &ggrid, - const Tensor &output_grad, - const Tensor &input_data, - const Tensor &grid) { + const Tensor &ggrid, + const Tensor &output_grad, + const Tensor &input_data, + const Tensor &grid, + const mxnet::OpReqType data_req, + const mxnet::OpReqType grid_req) { + using namespace mxnet; DType *g_input = input_grad.dptr_; DType *grad_grid = ggrid.dptr_; const DType *grid_src = grid.dptr_; @@ -196,8 +209,13 @@ inline void BilinearSamplerBackward(const Tensor &input_grad, dim3 threads_per_block(kMaxThreadsPerBlock); CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward"); cudaStream_t stream = Stream::GetStream(input_grad.stream_); - cuda::BilinearSamplerBackwardKernel << > >( - i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); + MXNET_REQ_TYPE_SWITCH(data_req, Req1, { + MXNET_REQ_TYPE_SWITCH(grid_req, Req2, { + cuda::BilinearSamplerBackwardKernel + <<>>( + i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid); + }); + }); // post kernel check cudaError err = cudaPeekAtLastError(); CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index f11a497c564c..e77569671ebb 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -111,6 +111,33 @@ inline int get_num_threads(const int N) { } +/*! \brief operator request type switch */ +#define MXNET_REQ_TYPE_SWITCH(req, ReqType, ...) \ + switch (req) { \ + case kNullOp: \ + { \ + const OpReqType ReqType = kNullOp; \ + {__VA_ARGS__} \ + } \ + break; \ + case kWriteInplace: \ + case kWriteTo: \ + { \ + const OpReqType ReqType = kWriteTo; \ + {__VA_ARGS__} \ + } \ + break; \ + case kAddTo: \ + { \ + const OpReqType ReqType = kAddTo; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + break; \ + } + + #define MXNET_NDIM_SWITCH(NDim, ndim, ...) \ if (NDim == 0) { \ } else if (NDim == 1) { \ diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 1fc2c8e922d9..d201a2e09c6d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1945,7 +1945,7 @@ def test_bilinear_sampler_versions(): exe.arg_dict['data'][:] = test_data exe.arg_dict['grid'][:] = test_grid exe.forward(is_train=True) - assert_almost_equal(exe_list[0].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5) + assert_almost_equal(exe_list[ref_idx].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5) out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32) for exe in exe_list: @@ -1975,6 +1975,22 @@ def test_bilinear_sampler_versions(): assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5) assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5) + for req_dict in [{'data' : 'null', 'grid' : 'write'}, {'data' : 'write', 'grid' : 'null'}]: + # Mixture of kWriteTo and kNullOp + exe_cpu_mix = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req=req_dict) + exe_gpu_mix = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict) + exe_cudnn_mix = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict) + exe_list = [exe_cpu_mix, exe_gpu_mix, exe_cudnn_mix] + for exe in exe_list: + exe.arg_dict['data'][:] = test_data + exe.arg_dict['grid'][:] = test_grid + exe.forward(is_train=True) + exe.backward(mx.nd.array(out_grad)) + if req_dict['data'] is 'write': + assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5) + if req_dict['grid'] is 'write': + assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) + def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.