Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change transfer gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
arcadiaphy committed May 16, 2019
1 parent 8d16468 commit 889fc6e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
1 change: 0 additions & 1 deletion src/operator/tensor/la_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,6 @@ struct syevd_backward {
struct inverse_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& B,
const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Expand Down
6 changes: 3 additions & 3 deletions src/operator/tensor/la_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -926,18 +926,18 @@ Examples::
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<FCompute>("FCompute<cpu>", LaOpForward<cpu, 2, 2, 1, 1, inverse>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_linalg_inverse"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_linalg_inverse"})
.add_argument("A", "NDArray-or-Symbol", "Tensor of square matrix");

NNVM_REGISTER_OP(_backward_linalg_inverse)
.set_num_inputs(3)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
{ return std::vector<std::pair<int, int> >{{0, 0}}; })
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 1, inverse_backward>);
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, inverse_backward>);

} // namespace op
} // namespace mxnet

0 comments on commit 889fc6e

Please sign in to comment.