Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix request
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed Dec 16, 2019
1 parent 4da7b06 commit 19c98b6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/operator/nn/im2col-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,21 @@ void Im2colCompute(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 3, DType> col = outputs[0].get_with_shape<xpu, 3, DType>(
Shape3(col_shape[0], col_shape[1], col_shape[2]), s);

for (index_t n = 0; n < num; ++n) {
im2col(s, im[n].dptr_, im_shape, col_buffer_shape,
param.kernel, param.pad, param.stride, param.dilate, col[n].dptr_);
if (req[0] == kNullOp) return;
if (req[0] != kAddTo) {
for (index_t n = 0; n < num; ++n) {
im2col(s, im[n].dptr_, im_shape, col_buffer_shape,
param.kernel, param.pad, param.stride, param.dilate, col[n].dptr_);
}
} else {
Tensor<xpu, 2, DType> tcol = ctx.requested[0]
.get_space_typed<xpu, 2, DType>(Shape2(col_shape[1], col_shape[2]), s);
for (index_t n = 0; n < num; ++n) {
im2col(s, im[n].dptr_, im_shape, col_buffer_shape,
param.kernel, param.pad, param.stride, param.dilate, tcol.dptr_);
Tensor<xpu, 2, DType> ocol = col[n];
ocol += tcol;
}
}
});
}
Expand Down Expand Up @@ -188,7 +200,7 @@ void Col2imCompute(const nnvm::NodeAttrs& attrs,
for (index_t n = 0; n < num; ++n) {
col2im(s, col[n].dptr_, im_shape, col_buffer_shape,
param.kernel, param.pad, param.stride, param.dilate,
im[n].dptr_, kWriteTo);
im[n].dptr_, req[0]);
}
});
}
Expand Down
4 changes: 4 additions & 0 deletions src/operator/nn/im2col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ Only 1-D, 2-D and 3-D of spatial dimension is supported in this operator.
TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", Im2colCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_im2col"})
.add_argument("data", "NDArray-or-Symbol", "Input array to extract sliding blocks.")
Expand Down

0 comments on commit 19c98b6

Please sign in to comment.