This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add support for more req patterns for bilinear sampler backward
Hao Jin committed Aug 28, 2018
1 parent e2a3eef commit 4da6c55
Showing 3 changed files with 108 additions and 93 deletions.
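In short: the backward pass previously hit LOG(FATAL) unless both input gradients were requested; with this change the data and grid gradient requests are forwarded separately into BilinearSamplerBackward, so each gradient can independently be skipped (kNullOp), overwritten (kWriteTo), or accumulated (kAddTo). As a reminder of the contract the kernels must honor, here is a minimal illustrative sketch (ours, not code from this commit):

// Illustrative sketch of MXNet's OpReqType contract for one gradient element.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };  // as in MXNet

void apply_req(float* dst, float src, OpReqType req) {
  switch (req) {
    case kNullOp:                      break;  // gradient not requested
    case kWriteTo:
    case kWriteInplace: *dst  = src;   break;  // overwrite the destination
    case kAddTo:        *dst += src;   break;  // accumulate into the destination
  }
}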
11 changes: 4 additions & 7 deletions src/operator/bilinear_sampler-inl.h
@@ -92,19 +92,16 @@ class BilinearSamplerOp : public Operator {
     Tensor<xpu, 4, DType> gdata = in_grad[bs::kData].get<xpu, 4, DType>(s);
     Tensor<xpu, 4, DType> ggrid = in_grad[bs::kGrid].get<xpu, 4, DType>(s);
     Tensor<xpu, 4, DType> grad = out_grad[bs::kOut].get<xpu, 4, DType>(s);
-    if (req[bs::kData] != kNullOp && req[bs::kGrid] != kNullOp) {
+    if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
+      return;
+    } else {
       if (req[bs::kData] == kWriteTo) {
         gdata = scalar<DType>(0.0f);
       }
       if (req[bs::kGrid] == kWriteTo) {
         ggrid = scalar<DType>(0.0f);
       }
-      BilinearSamplerBackward(gdata, ggrid, grad, data, grid);
-    } else if (req[bs::kData] == kNullOp && req[bs::kGrid] == kNullOp) {
-      return;
-    } else {
-      LOG(FATAL) << "Have not implemented the data req combinations! gdata_req="
-                 << req[bs::kData] << " ggrid_req=" << req[bs::kGrid];
+      BilinearSamplerBackward(gdata, ggrid, grad, data, grid, req[bs::kData], req[bs::kGrid]);
     }
   }

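Context for the kWriteTo branches above: BilinearSamplerBackward only ever accumulates into gdata and ggrid (+= on CPU, atomicAdd on GPU), so a write request is implemented as zero-then-accumulate, and a kAddTo request simply skips the zeroing and reuses the same kernels unchanged.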
90 changes: 48 additions & 42 deletions src/operator/bilinear_sampler.cc
@@ -78,10 +78,12 @@ inline void BilinearSamplerForward(const Tensor<cpu, 4, DType> &output,
 
 template<typename DType>
 inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
-                                    const Tensor<cpu, 4, DType> &ggrid,
-                                    const Tensor<cpu, 4, DType> &output_grad,
-                                    const Tensor<cpu, 4, DType> &input_data,
-                                    const Tensor<cpu, 4, DType> &grid) {
+                                    const Tensor<cpu, 4, DType> &ggrid,
+                                    const Tensor<cpu, 4, DType> &output_grad,
+                                    const Tensor<cpu, 4, DType> &input_data,
+                                    const Tensor<cpu, 4, DType> &grid,
+                                    const mxnet::OpReqType data_req,
+                                    const mxnet::OpReqType grid_req) {
   DType *g_input = gdata.dptr_;
   DType *grad_grid = ggrid.dptr_;
   const DType *grid_src = grid.dptr_;
@@ -102,46 +104,50 @@ inline void BilinearSamplerBackward(const Tensor<cpu, 4, DType> &gdata,
         int top_left_x = static_cast<int>(floor(x_real));
         DType top_left_y_w = 1.0 - (y_real - top_left_y);
         DType top_left_x_w = 1.0 - (x_real - top_left_x);
-        for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
-          index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-          int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
-                           + top_left_x;
-          // calc 4 vertex value in input data
-          DType top_left_v = 0;
-          DType top_right_v = 0;
-          DType bottom_left_v = 0;
-          DType bottom_right_v = 0;
-          // calc input grad
-          if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-            *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
-            top_left_v = *(data + data_index);
-          }
-          if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-            *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
-                                           * (1.0 - top_left_x_w);
-            top_right_v = *(data + data_index + 1);
-          }
-          if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-            *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
-                                            * top_left_x_w;
-            bottom_left_v = *(data + data_index + i_w);
-          }
-          if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-            *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
-                                                * (1.0 - top_left_x_w);
-            bottom_right_v = *(data + data_index + i_w + 1);
-          }
-          // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
-          top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
-                           (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
-                           * top_left_x_w);
-          top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
-                           (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
-                           * top_left_y_w);
-        }
-        // calc grad of grid
-        *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
-        *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+        if (data_req != mxnet::kNullOp) {
+          for (index_t c = 0; c < static_cast<index_t>(o_c); ++c) {
+            index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
+            int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w
+                             + top_left_x;
+            // calc 4 vertex value in input data
+            DType top_left_v = 0;
+            DType top_right_v = 0;
+            DType bottom_left_v = 0;
+            DType bottom_right_v = 0;
+            // calc input grad
+            if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+              *(g_input + data_index) += *(grad + grad_index) * top_left_y_w * top_left_x_w;
+              top_left_v = *(data + data_index);
+            }
+            if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+              *(g_input + data_index + 1) += *(grad + grad_index) * top_left_y_w
+                                             * (1.0 - top_left_x_w);
+              top_right_v = *(data + data_index + 1);
+            }
+            if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+              *(g_input + data_index+ i_w) += *(grad + grad_index) * (1.0 - top_left_y_w)
+                                              * top_left_x_w;
+              bottom_left_v = *(data + data_index + i_w);
+            }
+            if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+              *(g_input + data_index+ i_w + 1) += *(grad + grad_index) * (1.0 - top_left_y_w)
+                                                  * (1.0 - top_left_x_w);
+              bottom_right_v = *(data + data_index + i_w + 1);
+            }
+            // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
+            top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
+                             (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
+                             * top_left_x_w);
+            top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
+                             (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
+                             * top_left_y_w);
+          }
+        }
+        if (grid_req != mxnet::kNullOp) {
+          // calc grad of grid
+          *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
+          *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+        }
       }
     }
   }
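For reference, here is the derivation behind the "calc weight grad of top_left_w, then multiple -1 is the grad of grid_src" comment (ours, not part of the commit). Writing w_y, w_x for top_left_y_w, top_left_x_w and TL, TR, BL, BR for the four corner values, the forward interpolation is

    v = w_y w_x \cdot TL + w_y (1 - w_x) \cdot TR + (1 - w_y) w_x \cdot BL + (1 - w_y)(1 - w_x) \cdot BR

so

    \frac{\partial v}{\partial w_y} = TR - BR + (TL - TR - BL + BR)\, w_x
    \frac{\partial v}{\partial w_x} = BL - BR + (TL - TR - BL + BR)\, w_y

which are exactly the terms accumulated into top_left_y_gw and top_left_x_gw. Since w_x = 1 - (x - \lfloor x \rfloor), we have \partial w_x / \partial x = -1 (hence the minus sign in the accumulation), and since x = (g_x + 1)(i_w - 1)/2 maps the grid coordinate g_x \in [-1, 1] to pixel space, the grid gradient picks up a final factor of (i_w - 1)/2 (and (i_h - 1)/2 for y), matching the two *(grad_grid + ...) updates.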
100 changes: 56 additions & 44 deletions src/operator/bilinear_sampler.cu
@@ -79,7 +79,7 @@ __global__ void BilinearSamplerForwardKernel(const int i_c, const int i_h,
   }
 }
 
-template<typename DType>
+template<typename DType, int Req1, int Req2>
 __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
                                               const int i_w, const DType* grad,
                                               const DType* data, const int o_n,
@@ -104,45 +104,49 @@ __global__ void BilinearSamplerBackwardKernel(const int i_c, const int i_h,
     int top_left_x = static_cast<int>(floor(x_real));
     DType top_left_y_w = 1.0 - (y_real - top_left_y);
     DType top_left_x_w = 1.0 - (x_real - top_left_x);
-    for (index_t c = 0; c < o_c; ++c) {
-      index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
-      int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
-      // calc 4 vertex value in input data
-      DType top_left_v = 0;
-      DType top_right_v = 0;
-      DType bottom_left_v = 0;
-      DType bottom_right_v = 0;
-      // calc input grad
-      if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
-        top_left_v = *(data + data_index);
-      }
-      if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w
-                                            * (1.0 - top_left_x_w));
-        top_right_v = *(data + data_index + 1);
-      }
-      if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w)
-                                             * top_left_x_w);
-        bottom_left_v = *(data + data_index + i_w);
-      }
-      if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
-        atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w)
-                                                 * (1.0 - top_left_x_w));
-        bottom_right_v = *(data + data_index + i_w + 1);
-      }
-      // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
-      top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
-                       (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
-                       * top_left_x_w);
-      top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
-                       (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
-                       * top_left_y_w);
-    }
-    // calc grad of grid
-    *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
-    *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+    if (Req1 != mxnet::kNullOp) {
+      for (index_t c = 0; c < o_c; ++c) {
+        index_t grad_index = n * o_c * o_h * o_w + c * o_h * o_w + h * o_w + w;
+        int data_index = n * i_c * i_h * i_w + c * i_h * i_w + top_left_y * i_w + top_left_x;
+        // calc 4 vertex value in input data
+        DType top_left_v = 0;
+        DType top_right_v = 0;
+        DType bottom_left_v = 0;
+        DType bottom_right_v = 0;
+        // calc input grad
+        if (between(top_left_x, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+          atomicAdd(&g_input[data_index], *(grad + grad_index) * top_left_y_w * top_left_x_w);
+          top_left_v = *(data + data_index);
+        }
+        if (between(top_left_x+1, 0, i_w-1) && between(top_left_y, 0, i_h-1)) {
+          atomicAdd(&g_input[data_index + 1], *(grad + grad_index) * top_left_y_w
+                                              * (1.0 - top_left_x_w));
+          top_right_v = *(data + data_index + 1);
+        }
+        if (between(top_left_x, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+          atomicAdd(&g_input[data_index+ i_w], *(grad + grad_index) * (1.0 - top_left_y_w)
+                                               * top_left_x_w);
+          bottom_left_v = *(data + data_index + i_w);
+        }
+        if (between(top_left_x+1, 0, i_w-1) && between(top_left_y+1, 0, i_h-1)) {
+          atomicAdd(&g_input[data_index+ i_w + 1], *(grad + grad_index) * (1.0 - top_left_y_w)
+                                                   * (1.0 - top_left_x_w));
+          bottom_right_v = *(data + data_index + i_w + 1);
+        }
+        // calc weight grad of top_left_w, then multiple -1 is the grad of grid_src
+        top_left_y_gw -= *(grad + grad_index) * (top_right_v - bottom_right_v +
+                         (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
+                         * top_left_x_w);
+        top_left_x_gw -= *(grad + grad_index) * (bottom_left_v - bottom_right_v +
+                         (top_left_v - top_right_v - bottom_left_v + bottom_right_v)
+                         * top_left_y_w);
+      }
+    }
+    if (Req2 != mxnet::kNullOp) {
+      // calc grad of grid
+      *(grad_grid + grid_src_index + o_h * o_w) += top_left_y_gw * (i_h - 1) / 2;
+      *(grad_grid + grid_src_index) += top_left_x_gw * (i_w - 1) / 2;
+    }
   }
 }
 }  // namespace cuda
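One difference from the CPU path worth noting: the kernel scatters into g_input with atomicAdd rather than +=, because neighboring output pixels are processed by different threads and may bilinearly sample, and therefore write gradient into, the same input pixel concurrently.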
@@ -174,10 +178,13 @@ inline void BilinearSamplerForward(const Tensor<gpu, 4, DType> &output,
 
 template<typename DType>
 inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
-                                    const Tensor<gpu, 4, DType> &ggrid,
-                                    const Tensor<gpu, 4, DType> &output_grad,
-                                    const Tensor<gpu, 4, DType> &input_data,
-                                    const Tensor<gpu, 4, DType> &grid) {
+                                    const Tensor<gpu, 4, DType> &ggrid,
+                                    const Tensor<gpu, 4, DType> &output_grad,
+                                    const Tensor<gpu, 4, DType> &input_data,
+                                    const Tensor<gpu, 4, DType> &grid,
+                                    const mxnet::OpReqType data_req,
+                                    const mxnet::OpReqType grid_req) {
+  using namespace mxnet;
   DType *g_input = input_grad.dptr_;
   DType *grad_grid = ggrid.dptr_;
   const DType *grid_src = grid.dptr_;
@@ -196,8 +203,13 @@ inline void BilinearSamplerBackward(const Tensor<gpu, 4, DType> &input_grad,
   dim3 threads_per_block(kMaxThreadsPerBlock);
   CheckLaunchParam(num_blocks, threads_per_block, "bilinear sampler backward");
   cudaStream_t stream = Stream<gpu>::GetStream(input_grad.stream_);
-  cuda::BilinearSamplerBackwardKernel<DType> << <num_blocks, threads_per_block, 0, stream >> >(
-    i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
+  MXNET_ASSIGN_REQ_SWITCH(data_req, Req1, {
+    MXNET_ASSIGN_REQ_SWITCH(grid_req, Req2, {
+      cuda::BilinearSamplerBackwardKernel<DType, Req1, Req2>
+          <<<num_blocks, threads_per_block, 0, stream >>>(
+          i_c, i_h, i_w, grad, data, o_n, o_c, o_h, o_w, g_input, grid_src, grad_grid);
+    });
+  });
   // post kernel check
   cudaError err = cudaPeekAtLastError();
   CHECK_EQ(err, cudaSuccess) << cudaGetErrorString(err);
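MXNET_ASSIGN_REQ_SWITCH lifts the runtime OpReqType values into the compile-time template parameters Req1 and Req2, so each requested req combination gets its own specialized kernel and the if (Req1 != mxnet::kNullOp) tests inside the kernel body fold away at compile time. A minimal sketch of the same dispatch pattern (illustrative, assuming nothing about MXNet's actual macro internals):

#include <type_traits>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Re-invoke f with the request lifted to a compile-time constant.
template <typename F>
void req_switch(OpReqType req, F&& f) {
  switch (req) {
    case kNullOp: f(std::integral_constant<int, kNullOp>{});  break;
    case kAddTo:  f(std::integral_constant<int, kAddTo>{});   break;
    default:      f(std::integral_constant<int, kWriteTo>{}); break;
  }
}

// Nested use mirrors the launch site above:
//   req_switch(data_req, [&](auto r1) {
//     req_switch(grid_req, [&](auto r2) {
//       Kernel<DType, r1.value, r2.value><<<blocks, threads>>>(/* ... */);
//     });
//   });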
