[MXNET-876] make CachedOp a normal operator #11641
@@ -23,6 +23,7 @@
 #include "../executor/exec_pass.h"
 #include "../profiler/profiler.h"
 #include "../operator/operator_common.h"
+#include "../operator/subgraph/common.h"

 namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
   return op_state;
 }

 void CachedOp::DynamicBackward(
     const bool retain_graph,
     const OpStatePtr& op_state,
@@ -1067,34 +1067,145 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }

-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-                                  const int dev_mask,
-                                  DispatchMode* dispatch_mode,
-                                  std::vector<int> *in_attrs,
-                                  std::vector<int> *out_attrs) {
-  using namespace imperative;
-  nnvm::Graph g(fwd_graph_);
-  const auto& idx = g.indexed_graph();
-  const auto &outputs = idx.outputs();
-
-  // Prepare stypes and contexts based on inputs
-  StorageTypeVector storage_type_inputs;
-  storage_type_inputs.reserve(in_attrs->size());
-  for (size_t i = 0; i < in_attrs->size(); ++i) {
-    storage_type_inputs.emplace_back(in_attrs->at(i));
-  }
-  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
-
-  // Forward graph storage type inference
-  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-  // Retrieve result and set outputs
-  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < out_attrs->size(); i++) {
-    const auto eid = idx.entry_id(outputs[i]);
-    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
-  }
-  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-  return true;
-}
-
+/*
+ * This is the operator state of CachedOp when CachedOp is used in the symbol
+ * executor. This is different from the OpState returned by CachedOp::Forward.
+ * The main reason why we need this OpState is that CachedOp and the symbol executor
+ * maintain OpState differently. The symbol executor generates OpState in advance
+ * while CachedOp generates OpState after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
+ */
+struct CachedOpActualState {
Review comment: Missing class doc
+  std::shared_ptr<CachedOp> op;
+  OpStatePtr forward_state;
+
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+    this->op = op;
+  }
+};
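To make the lifecycle of this state concrete, here is a minimal standalone sketch, not part of the patch, using toy types in place of CachedOp and OpStatePtr: the executor creates the state up front, the forward wrapper stores whatever state Forward actually produced, and the backward wrapper consumes and then releases it.

#include <iostream>
#include <memory>
#include <string>

// Stand-ins for CachedOp and the OpStatePtr produced by its Forward pass.
struct ToyCachedOp {
  std::shared_ptr<std::string> Forward() {
    return std::make_shared<std::string>("state from Forward");
  }
  void Backward(const std::shared_ptr<std::string>& fwd_state) {
    std::cout << "Backward uses: " << *fwd_state << "\n";
  }
};

// Mirrors CachedOpActualState: created by the executor before Forward runs,
// it only caches the op plus whatever state Forward later produces.
struct ToyActualState {
  std::shared_ptr<ToyCachedOp> op;
  std::shared_ptr<std::string> forward_state;
  explicit ToyActualState(std::shared_ptr<ToyCachedOp> o) : op(std::move(o)) {}
};

int main() {
  ToyActualState s(std::make_shared<ToyCachedOp>());  // like FCreateOpState
  s.forward_state = s.op->Forward();                  // CachedOpForward stores the state
  s.op->Backward(s.forward_state);                    // CachedOpBackward consumes it
  s.forward_state.reset();                            // and releases it afterwards
  return 0;
}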
+/*
+ * This is the forward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
Review comment: Missing a short documentation stating what's the intention and how it works
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs(in_bufs.size());
+  std::vector<NDArray *> out_ptrs(out_bufs.size());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  // Set is_recording correctly for the imperative executor.
+  bool orig_is_record;
+  if (ctx.need_grad)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+  // Set is_training correctly for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+  Imperative::Get()->set_is_recording(orig_is_record);
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If so, we need to copy the data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
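CachedOpForward above flips the imperative executor's global recording/training flags for the duration of the nested Forward call and then puts them back. Below is a standalone sketch of that save-set-restore pattern, not part of the patch, with a toy GlobalFlag standing in for the Imperative singleton's flags.

#include <iostream>

// A toy stand-in for Imperative::Get()->is_training()/set_is_training().
struct GlobalFlag {
  bool value = false;
  bool get() const { return value; }
  bool set(bool v) { bool old = value; value = v; return old; }  // returns the previous value
};

GlobalFlag g_is_train;

void RunWithTraining(bool ctx_is_train) {
  // Save the old value; only flip the flag when the context asks for training.
  bool orig = ctx_is_train ? g_is_train.set(true) : g_is_train.get();
  std::cout << "running with is_train=" << g_is_train.get() << "\n";
  g_is_train.set(orig);  // restore, whether or not we changed it
}

int main() {
  RunWithTraining(true);
  std::cout << "restored to " << g_is_train.get() << "\n";
  return 0;
}

An RAII guard that restores the flag in its destructor would make the same pattern exception-safe, which the hand-rolled save-and-restore is not.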
+/*
+ * This is the backward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  size_t bwd_in_dep = s.op->num_inputs();
+  size_t bwd_out_dep = s.op->num_outputs();
+  CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
+  size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+  // Find inputs, outputs and ograds
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + bwd_in_dep;
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
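To make the index bookkeeping concrete, here is a standalone sketch, not part of the patch, with hypothetical sizes rather than a real CachedOp: two forward inputs, one forward output, and four backward inputs, so the backward input list is partitioned as [ograds | inputs | outputs].

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Hypothetical sizes, analogous to num_inputs(), num_outputs() and
  // num_backward_inputs() in the code above.
  const size_t bwd_in_dep = 2;   // forward inputs appended to the backward inputs
  const size_t bwd_out_dep = 1;  // forward outputs appended to the backward inputs
  const std::vector<std::string> in_bufs = {"ograd0", "in0", "in1", "out0"};
  const size_t bwd_ograd_dep = in_bufs.size() - bwd_in_dep - bwd_out_dep;  // = 1

  auto ograds_begin = in_bufs.begin();
  auto ograds_end   = in_bufs.begin() + bwd_ograd_dep;
  auto in_begin     = ograds_end;
  auto in_end       = in_begin + bwd_in_dep;
  auto out_begin    = in_end;
  auto out_end      = in_bufs.end();

  assert(*ograds_begin == "ograd0");
  assert(*in_begin == "in0" && *(in_end - 1) == "in1");
  assert(*out_begin == "out0" && out_end == in_bufs.end());
  return 0;
}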
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
Review comment: ++it is potentially faster
Reply: really? Where is this documented?
Reply: @zheng-da it's a well-known thing for old C++ farts. It's in reference C++ books like http://www.cppstdlib.com/ or Stroustrup: https://stackoverflow.com/questions/1077026/incrementing-iterators-it-more-efficient-than-it. In most cases it probably doesn't make a difference, especially for simple iterators where the iterator is just a pointer. That's why I said it is potentially faster. It's more of a good idiomatic practice to always use preincrement.
+    auto i = it - in_begin;
+    if (save_inputs[i])
+      in_ptrs.push_back(&(*it));
+  }
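A minimal standalone illustration of the review thread above, not part of the patch: post-increment has to hand back a copy of the iterator's old value, while pre-increment simply returns a reference, which is why preincrement is the idiomatic default even though an optimizer usually erases the difference for pointer-like iterators.

#include <iostream>

// A toy iterator that counts how many copies of itself are made.
struct CountingIter {
  static int copies;
  int pos = 0;
  CountingIter() = default;
  CountingIter(const CountingIter& other) : pos(other.pos) { ++copies; }
  CountingIter& operator++() { ++pos; return *this; }  // pre-increment: no copy
  CountingIter operator++(int) { CountingIter old(*this); ++pos; return old; }  // post-increment: copies the old value
};
int CountingIter::copies = 0;

int main() {
  CountingIter it;
  for (int i = 0; i < 1000; i++) it++;   // post-increment: at least one copy per step
  std::cout << "post-increment copies: " << CountingIter::copies << "\n";
  CountingIter::copies = 0;
  for (int i = 0; i < 1000; i++) ++it;   // pre-increment: no copies
  std::cout << "pre-increment copies:  " << CountingIter::copies << "\n";
  return 0;
}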
+  for (auto it = out_begin; it != out_end; it++) {
+    auto i = it - out_begin;
+    if (save_outputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    out_ptrs.push_back(&out_bufs[i]);
+  CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+  // Set is_training correctly for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  // TODO(zhengda) is it right to use false here?
+  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Review comment: please add more comment on …
+  Imperative::Get()->set_is_training(orig_is_train);
+
+  // Clean up what we recorded.
+  s.forward_state.reset();
+
+  // The arrays in out_ptrs may be changed by CachedOp.
Review comment: why would it be changed?
Reply: Thanks for updating the comments.
+  // If so, we need to copy the data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
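The copy-back at the end of both wrappers guards against CachedOp swapping in its own output arrays (for example, when an output's storage has to change): if the buffer that was written is no longer the one the executor supplied, its contents are copied back. Below is a standalone sketch of that check-and-copy-back pattern, not part of the patch, with toy handle types in place of NDArray and IsSame/CopyFromTo.

#include <iostream>
#include <memory>
#include <vector>

// A toy "array" that, like NDArray, is a handle to shared storage.
using ToyArray = std::shared_ptr<std::vector<float>>;

bool IsSameHandle(const ToyArray& a, const ToyArray& b) { return a == b; }

// A callee that may replace the output handle instead of writing in place.
void Callee(ToyArray* out) {
  *out = std::make_shared<std::vector<float>>(std::vector<float>{1.f, 2.f, 3.f});
}

int main() {
  ToyArray executor_buf = std::make_shared<std::vector<float>>(3, 0.f);
  ToyArray out_buf = executor_buf;        // hand the callee a copy of the handle
  Callee(&out_buf);
  if (!IsSameHandle(out_buf, executor_buf))  // the callee swapped the buffer,
    *executor_buf = *out_buf;                // so copy the data back into the executor's array
  std::cout << (*executor_buf)[2] << "\n";   // prints 3
  return 0;
}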
+OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
+                               Context ctx,
+                               const std::vector<TShape>& in_shapes,
+                               const std::vector<int>& in_types) {
+  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+  return OpStatePtr::Create<CachedOpActualState>(op);
+}
+
 bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1254,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }

+void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
+  CachedOpConfig param;
+  try {
+    param.Init(attrs->dict);
+  } catch (const dmlc::ParamError& e) {
+    std::ostringstream os;
+    os << e.what();
+    os << ", in operator " << attrs->op->name << "("
+       << "name=\"" << attrs->name << "\"";
+    for (const auto& k : attrs->dict) {
+      os << ", " << k.first << "=\"" << k.second << "\"";
+    }
+    os << ")";
+    throw dmlc::ParamError(os.str());
+  }
+  if (!param.subgraph.empty()) {
+    nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
+    CHECK(!g.outputs.empty());
+    nnvm::Symbol sym;
+    sym.outputs = g.outputs;
+    std::vector<std::pair<std::string, std::string> > flags;
+    for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
+      flags.emplace_back(it->first, it->second);
+    attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
+  }
+}
+
 NNVM_REGISTER_OP(_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs) {

@@ -1153,19 +1290,62 @@ NNVM_REGISTER_OP(_CachedOp)
   const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
   return op->num_outputs();
 })
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
-                                                     const int dev_mask,
-                                                     DispatchMode* dispatch_mode,
-                                                     std::vector<int> *in_attrs,
-                                                     std::vector<int> *out_attrs) {
-  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
-  return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
-})
+.set_attr_parser(CachedOpParamParser)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
     return op->Gradient(n, ograds);
-  });
+  })
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardInputNames();
+  })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op->ListForwardOutputNames();
+  })
+.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
+.set_attr<nnvm::FInferShape>("FInferShape",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<TShape> *in_shapes,
+     std::vector<TShape> *out_shapes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
+  })
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_types,
+     std::vector<int> *out_types) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
+  })
+.set_attr<FInferStorageType>("FInferStorageType",
+  [](const nnvm::NodeAttrs& attrs,
+     const int dev_mask,
+     DispatchMode* dispatch_mode,
+     std::vector<int>* in_stypes,
+     std::vector<int>* out_stypes) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
+                                                  dev_mask, dispatch_mode,
+                                                  in_stypes, out_stypes);
+  })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
+.set_attr<nnvm::FMutateInputs>("FMutateInputs",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
+    return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
+  })
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
+.add_argument("data", "NDArray-or-Symbol[]", "input data list");

 NNVM_REGISTER_OP(_backward_CachedOp)
 .set_num_inputs([](const NodeAttrs& attrs){

@@ -1184,6 +1364,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
   const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
   return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
 })
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
+.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
+.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
 .set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<bool>("TIsBackward", true);
Review comment: Please elaborate on the necessity of adding this data structure in the description.