diff --git a/src/operator/nn/im2col-inl.h b/src/operator/nn/im2col-inl.h index c7f9bde93862..b5caa035f911 100644 --- a/src/operator/nn/im2col-inl.h +++ b/src/operator/nn/im2col-inl.h @@ -87,9 +87,21 @@ void Im2colCompute(const nnvm::NodeAttrs& attrs, Tensor col = outputs[0].get_with_shape( Shape3(col_shape[0], col_shape[1], col_shape[2]), s); - for (index_t n = 0; n < num; ++n) { - im2col(s, im[n].dptr_, im_shape, col_buffer_shape, - param.kernel, param.pad, param.stride, param.dilate, col[n].dptr_); + if (req[0] == kNullOp) return; + if (req[0] != kAddTo) { + for (index_t n = 0; n < num; ++n) { + im2col(s, im[n].dptr_, im_shape, col_buffer_shape, + param.kernel, param.pad, param.stride, param.dilate, col[n].dptr_); + } + } else { + Tensor tcol = ctx.requested[0] + .get_space_typed(Shape2(col_shape[1], col_shape[2]), s); + for (index_t n = 0; n < num; ++n) { + im2col(s, im[n].dptr_, im_shape, col_buffer_shape, + param.kernel, param.pad, param.stride, param.dilate, tcol.dptr_); + Tensor ocol = col[n]; + ocol += tcol; + } } }); } @@ -188,7 +200,7 @@ void Col2imCompute(const nnvm::NodeAttrs& attrs, for (index_t n = 0; n < num; ++n) { col2im(s, col[n].dptr_, im_shape, col_buffer_shape, param.kernel, param.pad, param.stride, param.dilate, - im[n].dptr_, kWriteTo); + im[n].dptr_, req[0]); } }); } diff --git a/src/operator/nn/im2col.cc b/src/operator/nn/im2col.cc index 4aa4eaaf77d0..ae493f1bc594 100644 --- a/src/operator/nn/im2col.cc +++ b/src/operator/nn/im2col.cc @@ -149,6 +149,10 @@ Only 1-D, 2-D and 3-D of spatial dimension is supported in this operator. TYPE_ASSIGN_CHECK(*out_type, 0, dtype); return true; }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; +}) .set_attr("FCompute", Im2colCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_im2col"}) .add_argument("data", "NDArray-or-Symbol", "Input array to extract sliding blocks.")