Skip to content

Commit

Permalink
Add support for more req patterns for bilinear sampler backward (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and anirudh2290 committed Sep 19, 2018
1 parent 61c5a5c commit 3df22a8
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 41 deletions.
11 changes: 4 additions & 7 deletions src/operator/bilinear_sampler-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,16 @@ class BilinearSamplerOp : public Operator {
Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) {
if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
return;
} else {
if (req[bs::kData] == kWriteTo) {
gdata = scalar<DType>(0.0f);
}
if (req[bs::kGrid] == kWriteTo) {
ggrid = scalar<DType>(0.0f);
}
BilinearSamplerBackward(gdata, ggrid, grad, data, grid);
} else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
return;
} else {
LOG(FATAL) << "Have not implemented the data req combinations! gdata_req="
<< req[bs::kData] << " ggrid_req=" << req[bs::kGrid];
BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]);
}
}

Expand Down
43 changes: 27 additions & 16 deletions src/operator/bilinear_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output,

template<typename DType>
inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
const Tensor<cpu, 4, DType> &ggrid,
const Tensor<cpu, 4, DType> &output_grad,
const Tensor<cpu, 4, DType> &input_data,
const Tensor<cpu, 4, DType> &grid) {
const Tensor<cpu, 4, DType> &ggrid,
const Tensor<cpu, 4, DType> &output_grad,
const Tensor<cpu, 4, DType> &input_data,
const Tensor<cpu, 4, DType> &grid,
const mxnet::OpReqType data_req,
const mxnet::OpReqType grid_req) {
DType *g_input = gdata.dptr_;
DType *grad_grid = ggrid.dptr_;
const DType *grid_src = grid.dptr_;
Expand All @@ -104,31 +106,38 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
+ top_left_x;
int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
// calc input grad
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
if (data_req != mxnet::kNullOp) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
}
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w);
if (data_req != mxnet::kNullOp) {
*(g_input + data_index + 1) +=
*(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w);
}
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w;
if (data_req != mxnet::kNullOp) {
*(g_input + data_index+ i_w) +=
*(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w;
}
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w);
if (data_req != mxnet::kNullOp) {
*(g_input + data_index+ i_w + 1) +=
*(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
}
bottom_right_v = *(data + data_index + i_w + 1);
}
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
Expand All @@ -139,9 +148,11 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_y_w);
}
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
if (grid_req != mxnet::kNullOp) {
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
}
}
}
}
Expand Down
52 changes: 35 additions & 17 deletions src/operator/bilinear_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
}
}

template<typename DType>
template<typename DType, int Req1, int Req2>
__global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
Expand Down Expand Up @@ -114,22 +114,30 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
DType bottom_right_v = 0;
// calc input grad
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
}
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w));
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index + 1],
*(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w));
}
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w);
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index+ i_w],
*(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
}
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w));
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index+ i_w + 1],
*(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w));
}
bottom_right_v = *(data + data_index + i_w + 1);
}
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
Expand All @@ -140,9 +148,11 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_y_w);
}
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
if (Req2 != mxnet::kNullOp) {
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
}
}
}
} // namespace cuda
Expand Down Expand Up @@ -174,10 +184,13 @@ inline void BilinearSamplerForward(const Tensor<gpu, 4, DType> &output,

template<typename DType>
inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
const Tensor<gpu, 4, DType> &ggrid,
const Tensor<gpu, 4, DType> &output_grad,
const Tensor<gpu, 4, DType> &input_data,
const Tensor<gpu, 4, DType> &grid) {
const Tensor<gpu, 4, DType> &ggrid,
const Tensor<gpu, 4, DType> &output_grad,
const Tensor<gpu, 4, DType> &input_data,
const Tensor<gpu, 4, DType> &grid,
const mxnet::OpReqType data_req,
const mxnet::OpReqType grid_req) {
using namespace mxnet;
DType *g_input = input_grad.dptr_;
DType *grad_grid = ggrid.dptr_;
const DType *grid_src = grid.dptr_;
Expand All @@ -196,8 +209,13 @@ inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
dim3 threads_per_block(kMaxThreadsPerBlock);
CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward");
cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
cuda::BilinearSamplerBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
MXNET_REQ_TYPE_SWITCH(data_req, Req1, {
MXNET_REQ_TYPE_SWITCH(grid_req, Req2, {
cuda::BilinearSamplerBackwardKernel<DType, Req1, Req2>
<<<num_blocks, threads_per_block, 0, stream >>>(
i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
});
});
// post kernel check
cudaError err = cudaPeekAtLastError();
CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
Expand Down
27 changes: 27 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,33 @@ inline int get_num_threads<cpu>(const int N) {
}


/*! \brief operator request type switch */
#define MXNET_REQ_TYPE_SWITCH(req, ReqType, ...) \
switch (req) { \
case kNullOp: \
{ \
const OpReqType ReqType = kNullOp; \
{__VA_ARGS__} \
} \
break; \
case kWriteInplace: \
case kWriteTo: \
{ \
const OpReqType ReqType = kWriteTo; \
{__VA_ARGS__} \
} \
break; \
case kAddTo: \
{ \
const OpReqType ReqType = kAddTo; \
{__VA_ARGS__} \
} \
break; \
default: \
break; \
}


#define MXNET_NDIM_SWITCH(NDim, ndim, ...) \
if (NDim == 0) { \
} else if (NDim == 1) { \
Expand Down
18 changes: 17 additions & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,7 +1945,7 @@ def test_bilinear_sampler_versions():
exe.arg_dict['data'][:] = test_data
exe.arg_dict['grid'][:] = test_grid
exe.forward(is_train=True)
assert_almost_equal(exe_list[0].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)
assert_almost_equal(exe_list[ref_idx].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)

out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32)
for exe in exe_list:
Expand Down Expand Up @@ -1975,6 +1975,22 @@ def test_bilinear_sampler_versions():
assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5)
assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5)

for req_dict in [{'data' : 'null', 'grid' : 'write'}, {'data' : 'write', 'grid' : 'null'}]:
# Mixture of kWriteTo and kNullOp
exe_cpu_mix = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req=req_dict)
exe_gpu_mix = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
exe_cudnn_mix = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
exe_list = [exe_cpu_mix, exe_gpu_mix, exe_cudnn_mix]
for exe in exe_list:
exe.arg_dict['data'][:] = test_data
exe.arg_dict['grid'][:] = test_grid
exe.forward(is_train=True)
exe.backward(mx.nd.array(out_grad))
if req_dict['data'] is 'write':
assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5)
if req_dict['grid'] is 'write':
assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5)


def test_context_num_gpus():
# Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
Expand Down

0 comments on commit 3df22a8

Please sign in to comment.