Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gyshi committed Sep 16, 2019
1 parent f14f20e commit e1065ab
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
20 changes: 10 additions & 10 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,18 +398,18 @@ NNVM_REGISTER_OP(_np_roll)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<mxnet::FCompute>("FCompute<cpu>", NumpyRollCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const NumpyRollParam& param = nnvm::get<NumpyRollParam>(n->attrs.parsed);
std::ostringstream os1;
os1 << param.shift;
std::ostringstream os2;
os2 << param.axis;
return MakeNonlossGradNode("_np_roll", n, ograds, {},
{{"shift", os1.str()}, {"axis", os2.str()}});
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const NumpyRollParam& param = nnvm::get<NumpyRollParam>(n->attrs.parsed);
std::ostringstream os1;
os1 << param.shift;
std::ostringstream os2;
os2 << param.axis;
return MakeNonlossGradNode("_np_roll", n, ograds, {},
{{"shift", os1.str()}, {"axis", os2.str()}});
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollParam::__FIELDS__());
Expand Down
1 change: 0 additions & 1 deletion tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,6 @@ def hybrid_forward(self, F, x):
assert same(mx_out.asnumpy(), np_out)
assert same(x.grad.shape, x.shape)
assert same(x.grad.asnumpy(), _np.ones(shape))

# test imperativen
np_out = _np.roll(x.asnumpy(), shift=shift, axis=axis)
mx_out = np.roll(x, shift=shift, axis=axis)
Expand Down

0 comments on commit e1065ab

Please sign in to comment.