Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-810] Add support for more req patterns for bilinear sampler backward #12386

Merged
merged 1 commit into from
Sep 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/operator/bilinear_sampler-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,16 @@ class BilinearSamplerOp : public Operator {
Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) {
if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
return;
} else {
if (req[bs::kData] == kWriteTo) {
gdata = scalar<DType>(0.0f);
}
if (req[bs::kGrid] == kWriteTo) {
ggrid = scalar<DType>(0.0f);
}
BilinearSamplerBackward(gdata, ggrid, grad, data, grid);
} else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
return;
} else {
LOG(FATAL) << "Have not implemented the data req combinations! gdata_req="
<< req[bs::kData] << " ggrid_req=" << req[bs::kGrid];
BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]);
}
}

Expand Down
43 changes: 27 additions & 16 deletions src/operator/bilinear_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output,

template<typename DType>
inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
const Tensor<cpu, 4, DType> &ggrid,
const Tensor<cpu, 4, DType> &output_grad,
const Tensor<cpu, 4, DType> &input_data,
const Tensor<cpu, 4, DType> &grid) {
const Tensor<cpu, 4, DType> &ggrid,
const Tensor<cpu, 4, DType> &output_grad,
const Tensor<cpu, 4, DType> &input_data,
const Tensor<cpu, 4, DType> &grid,
const mxnet::OpReqType data_req,
const mxnet::OpReqType grid_req) {
DType *g_input = gdata.dptr_;
DType *grad_grid = ggrid.dptr_;
const DType *grid_src = grid.dptr_;
Expand All @@ -104,31 +106,38 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
DType top_left_x_w = 1.0 - (x_real - top_left_x);
for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
+ top_left_x;
int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
// calc 4 vertex value in input data
DType top_left_v = 0;
DType top_right_v = 0;
DType bottom_left_v = 0;
DType bottom_right_v = 0;
// calc input grad
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
if (data_req != mxnet::kNullOp) {
*(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
}
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
*(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w);
if (data_req != mxnet::kNullOp) {
*(g_input + data_index + 1) +=
*(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w);
}
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
*(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w;
if (data_req != mxnet::kNullOp) {
*(g_input + data_index+ i_w) +=
*(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w;
}
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
*(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w);
if (data_req != mxnet::kNullOp) {
*(g_input + data_index+ i_w + 1) +=
*(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w);
}
bottom_right_v = *(data + data_index + i_w + 1);
}
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
Expand All @@ -139,9 +148,11 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_y_w);
}
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
if (grid_req != mxnet::kNullOp) {
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
}
}
}
}
Expand Down
52 changes: 35 additions & 17 deletions src/operator/bilinear_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
}
}

template<typename DType>
template<typename DType, int Req1, int Req2>
__global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
Expand Down Expand Up @@ -114,22 +114,30 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
DType bottom_right_v = 0;
// calc input grad
if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
}
top_left_v = *(data + data_index);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w
* (1.0 - top_left_x_w));
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index + 1],
*(grad + grad_index) * top_left_y_w * (1.0 - top_left_x_w));
}
top_right_v = *(data + data_index + 1);
}
if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w)
* top_left_x_w);
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index+ i_w],
*(grad + grad_index) * (1.0 - top_left_y_w) * top_left_x_w);
}
bottom_left_v = *(data + data_index + i_w);
}
if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w)
* (1.0 - top_left_x_w));
if (Req1 != mxnet::kNullOp) {
atomicAdd(&g_input[data_index+ i_w + 1],
*(grad + grad_index) * (1.0 - top_left_y_w) * (1.0 - top_left_x_w));
}
bottom_right_v = *(data + data_index + i_w + 1);
}
// calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
Expand All @@ -140,9 +148,11 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
(top_left_v - top_right_v - bottom_left_v + bottom_right_v)
* top_left_y_w);
}
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
if (Req2 != mxnet::kNullOp) {
// calc grad of grid
*(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
*(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
}
}
}
} // namespace cuda
Expand Down Expand Up @@ -174,10 +184,13 @@ inline void BilinearSamplerForward(const Tensor<gpu, 4, DType> &output,

template<typename DType>
inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
const Tensor<gpu, 4, DType> &ggrid,
const Tensor<gpu, 4, DType> &output_grad,
const Tensor<gpu, 4, DType> &input_data,
const Tensor<gpu, 4, DType> &grid) {
const Tensor<gpu, 4, DType> &ggrid,
const Tensor<gpu, 4, DType> &output_grad,
const Tensor<gpu, 4, DType> &input_data,
const Tensor<gpu, 4, DType> &grid,
const mxnet::OpReqType data_req,
const mxnet::OpReqType grid_req) {
using namespace mxnet;
DType *g_input = input_grad.dptr_;
DType *grad_grid = ggrid.dptr_;
const DType *grid_src = grid.dptr_;
Expand All @@ -196,8 +209,13 @@ inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
dim3 threads_per_block(kMaxThreadsPerBlock);
CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward");
cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
cuda::BilinearSamplerBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
MXNET_REQ_TYPE_SWITCH(data_req, Req1, {
MXNET_REQ_TYPE_SWITCH(grid_req, Req2, {
cuda::BilinearSamplerBackwardKernel<DType, Req1, Req2>
<<<num_blocks, threads_per_block, 0, stream >>>(
i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
});
});
// post kernel check
cudaError err = cudaPeekAtLastError();
CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
Expand Down
27 changes: 27 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,33 @@ inline int get_num_threads<cpu>(const int N) {
}


/*! \brief operator request type switch */
#define MXNET_REQ_TYPE_SWITCH(req, ReqType, ...) \
haojin2 marked this conversation as resolved.
Show resolved Hide resolved
switch (req) { \
case kNullOp: \
{ \
const OpReqType ReqType = kNullOp; \
{__VA_ARGS__} \
} \
break; \
case kWriteInplace: \
case kWriteTo: \
{ \
const OpReqType ReqType = kWriteTo; \
{__VA_ARGS__} \
} \
break; \
case kAddTo: \
{ \
const OpReqType ReqType = kAddTo; \
{__VA_ARGS__} \
} \
break; \
default: \
break; \
}


#define MXNET_NDIM_SWITCH(NDim, ndim, ...) \
if (NDim == 0) { \
} else if (NDim == 1) { \
Expand Down
18 changes: 17 additions & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1945,7 +1945,7 @@ def test_bilinear_sampler_versions():
exe.arg_dict['data'][:] = test_data
exe.arg_dict['grid'][:] = test_grid
exe.forward(is_train=True)
assert_almost_equal(exe_list[0].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)
assert_almost_equal(exe_list[ref_idx].outputs[0].asnumpy(), exe.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5)

out_grad = np.random.uniform(low=-0.01, high=0.01,size=data_shape[:2] + grid_shape[2:]).astype(np.float32)
for exe in exe_list:
Expand Down Expand Up @@ -1975,6 +1975,22 @@ def test_bilinear_sampler_versions():
assert_almost_equal(exe_list[ref_idx].grad_dict['data'].asnumpy(), data_grad + data_initial_grad, rtol=1e-3, atol=1e-5)
assert_almost_equal(exe_list[ref_idx].grad_dict['grid'].asnumpy(), grid_grad + grid_initial_grad, rtol=1e-3, atol=1e-5)

for req_dict in [{'data' : 'null', 'grid' : 'write'}, {'data' : 'write', 'grid' : 'null'}]:
# Mixture of kWriteTo and kNullOp
exe_cpu_mix = sym1.simple_bind(data=data_shape, grid=grid_shape, ctx=mx.cpu(), grad_req=req_dict)
exe_gpu_mix = sym2.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
exe_cudnn_mix = sym3.simple_bind(data=data_shape, grid=grid_shape, ctx=default_context(), grad_req=req_dict)
exe_list = [exe_cpu_mix, exe_gpu_mix, exe_cudnn_mix]
for exe in exe_list:
exe.arg_dict['data'][:] = test_data
exe.arg_dict['grid'][:] = test_grid
exe.forward(is_train=True)
exe.backward(mx.nd.array(out_grad))
if req_dict['data'] is 'write':
assert_almost_equal(exe.grad_dict['data'].asnumpy(), exe_list[ref_idx].grad_dict['data'].asnumpy(), rtol=1e-3, atol=1e-5)
if req_dict['grid'] is 'write':
assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5)


def test_context_num_gpus():
# Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
Expand Down