Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
simplify implementation of second order gradient for FC
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Jun 5, 2019
1 parent 745ea66 commit 7458af5
Showing 1 changed file with 1 addition and 17 deletions.
18 changes: 1 addition & 17 deletions src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,22 +176,6 @@ struct FullyConnectedGrad {
}
};

std::vector<nnvm::NodeEntry> FullyConnectedBackwardGrad(
const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
// Note this is not strictly correct but we don't expect inputs to depend on weights at the
// moment. If you find such a case, please contribute a more elaborate implementation.
std::vector<nnvm::NodeEntry> ret;
size_t i = 0;
for (const auto& x : n->inputs) {
std::ostringstream os;
os << n->attrs.name << "_backward_" << i;
ret.emplace_back(MakeNode("zeros_like", os.str(), {x}, nullptr, &n));
++i;
}
return ret;
}

inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
Expand Down Expand Up @@ -341,7 +325,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedBackwardGrad)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
Expand Down

0 comments on commit 7458af5

Please sign in to comment.