Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add backward to fully connected. (_backward_FullyConnected)
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Apr 23, 2019
1 parent a386644 commit a6ca88a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/operator/nn/fully_connected-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
}
}


} // namespace op
} // namespace mxnet
namespace std {
Expand Down
16 changes: 15 additions & 1 deletion src/operator/nn/fully_connected.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
Expand Down Expand Up @@ -165,6 +165,7 @@ static bool FullyConnectedType(const nnvm::NodeAttrs& attrs,
attrs, in_type, out_type, -1);
}


struct FullyConnectedGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
Expand All @@ -176,6 +177,16 @@ struct FullyConnectedGrad {
}
};


std::vector<nnvm::NodeEntry> FullyConnectedBackwardGrad(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
auto zero_node = MakeNode("zeros_like", n->attrs.name + "_backward", {n->inputs[0]}, nullptr, &n);
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(nnvm::NodeEntry{zero_node, 0, 0});
return ret;
}


inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
Expand Down Expand Up @@ -310,6 +321,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.")
.add_arguments(FullyConnectedParam::__FIELDS__());


NNVM_REGISTER_OP(_backward_FullyConnected)
.set_num_inputs(3)
.set_num_outputs([](const NodeAttrs& attrs) {
Expand All @@ -325,6 +337,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedBackwardGrad)
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
Expand All @@ -333,5 +346,6 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
#endif
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);


} // namespace op
} // namespace mxnet

0 comments on commit a6ca88a

Please sign in to comment.