This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-876] make CachedOp a normal operator #11641
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
f04686a
extend _CachedOp to be a regular operator.
zheng-da b7f1066
use default subgraph infer.
zheng-da 33895d1
fix test.
zheng-da 9b0771c
fix compilation error.
zheng-da f253ade
use default subgraph stuff.
zheng-da daf3801
add comments.
zheng-da 15751a4
fix.
zheng-da 7c175c5
use a more general InferStorage.
zheng-da 8fa5ac8
use cachedOp as default subgraph operator.
zheng-da ded2cc2
remove default subgraph op.
zheng-da 7e722a9
fix.
zheng-da f7b63b4
fix.
zheng-da 58d4851
rename.
zheng-da 9f1aa26
add comment.
zheng-da d4c95d3
retrigger
zheng-da fc001dc
add comments.
zheng-da File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
#include "../executor/exec_pass.h" | ||
#include "../profiler/profiler.h" | ||
#include "../operator/operator_common.h" | ||
#include "../operator/subgraph/common.h" | ||
|
||
|
||
namespace mxnet { | ||
|
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward( | |
return op_state; | ||
} | ||
|
||
|
||
void CachedOp::DynamicBackward( | ||
const bool retain_graph, | ||
const OpStatePtr& op_state, | ||
|
@@ -1067,34 +1067,153 @@ void CachedOp::Backward( | |
Engine::Get()->set_bulk_size(prev_bulk_size); | ||
} | ||
|
||
bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
using namespace imperative; | ||
nnvm::Graph g(fwd_graph_); | ||
const auto& idx = g.indexed_graph(); | ||
const auto &outputs = idx.outputs(); | ||
/* | ||
* This is the operator state of CachedOp when CachedOp is used in the symbol | ||
* executor. This is different from the OpState returned by CachedOp::Forward. | ||
* The main reason why we need this OpState is that CachedOp and the symbol executor | ||
* maintain OpState differently. The symbol executor generates OpState in advance | ||
* while CachedOp generates OpState after Forward is called. We need this data | ||
* structure to keep the OpState generated by CachedOp::Forward and pass it to | ||
* Backward. | ||
*/ | ||
struct CachedOpActualState { | ||
std::shared_ptr<CachedOp> op; | ||
OpStatePtr forward_state; | ||
|
||
// Prepare stypes and contexts based on inputs | ||
StorageTypeVector storage_type_inputs; | ||
storage_type_inputs.reserve(in_attrs->size()); | ||
for (size_t i = 0; i < in_attrs->size(); ++i) { | ||
storage_type_inputs.emplace_back(in_attrs->at(i)); | ||
explicit CachedOpActualState(std::shared_ptr<CachedOp> op) { | ||
this->op = op; | ||
} | ||
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask); | ||
}; | ||
|
||
// Forward graph storage type inference | ||
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true); | ||
// Retrieve result and set outputs | ||
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type"); | ||
for (size_t i = 0; i < out_attrs->size(); i++) { | ||
const auto eid = idx.entry_id(outputs[i]); | ||
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]); | ||
/* | ||
* This is the forward computation when CachedOp is used as an operator in | ||
* a symbol executor. | ||
*/ | ||
void CachedOpForward(const OpStatePtr& state_ptr, | ||
const OpContext& ctx, | ||
const std::vector<NDArray>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<NDArray>& outputs) { | ||
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>(); | ||
std::vector<NDArray> in_bufs = inputs; | ||
std::vector<NDArray> out_bufs = outputs; | ||
std::vector<NDArray *> in_ptrs(in_bufs.size()); | ||
std::vector<NDArray *> out_ptrs(out_bufs.size()); | ||
for (size_t i = 0; i < in_ptrs.size(); i++) | ||
in_ptrs[i] = &in_bufs[i]; | ||
for (size_t i = 0; i < out_ptrs.size(); i++) | ||
out_ptrs[i] = &out_bufs[i]; | ||
|
||
// Set is_recording correct for the imperative executor. | ||
bool orig_is_record; | ||
if (ctx.need_grad) | ||
orig_is_record = Imperative::Get()->set_is_recording(true); | ||
else | ||
orig_is_record = Imperative::Get()->is_recording(); | ||
// Set is_training correct for the imperative executor. | ||
bool orig_is_train; | ||
if (ctx.is_train) | ||
orig_is_train = Imperative::Get()->set_is_training(true); | ||
else | ||
orig_is_train = Imperative::Get()->is_training(); | ||
s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs); | ||
Imperative::Get()->set_is_training(orig_is_train); | ||
Imperative::Get()->set_is_recording(orig_is_record); | ||
// The arrays in out_ptrs may be changed by CachedOp. | ||
// If it is, we need to copy data back. | ||
for (size_t i = 0; i < out_bufs.size(); i++) | ||
if (!out_bufs[i].IsSame(outputs[i])) | ||
CopyFromTo(out_bufs[i], outputs[i]); | ||
} | ||
|
||
/* | ||
* This is the backward computation when CachedOp is used as an operator in | ||
* a symbol executor. | ||
*/ | ||
void CachedOpBackward(const OpStatePtr& state_ptr, | ||
const OpContext& ctx, | ||
const std::vector<NDArray>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<NDArray>& outputs) { | ||
using namespace nnvm; | ||
using namespace imperative; | ||
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>(); | ||
std::vector<NDArray> in_bufs = inputs; | ||
std::vector<NDArray> out_bufs = outputs; | ||
std::vector<NDArray *> in_ptrs; | ||
std::vector<NDArray *> out_ptrs; | ||
CHECK_EQ(s.op->num_backward_inputs(), inputs.size()); | ||
in_ptrs.reserve(s.op->num_backward_inputs()); | ||
out_ptrs.reserve(s.op->num_inputs()); | ||
|
||
const std::vector<bool> &save_inputs = s.op->save_inputs(); | ||
const std::vector<bool> &save_outputs = s.op->save_outputs(); | ||
size_t bwd_in_dep = s.op->num_inputs(); | ||
size_t bwd_out_dep = s.op->num_outputs(); | ||
CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep); | ||
size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep; | ||
|
||
// Find inputs, outputs and ograds | ||
auto ograds_begin = in_bufs.begin(); | ||
auto ograds_end = in_bufs.begin() + bwd_ograd_dep; | ||
auto in_begin = ograds_end; | ||
auto in_end = in_begin + bwd_in_dep; | ||
auto out_begin = in_end; | ||
auto out_end = in_bufs.end(); | ||
|
||
for (auto it = ograds_begin; it != ograds_end; it++) | ||
in_ptrs.push_back(&(*it)); | ||
|
||
CHECK_EQ(save_inputs.size(), in_end - in_begin); | ||
CHECK_EQ(s.op->num_outputs(), out_end - out_begin); | ||
for (auto it = in_begin; it != in_end; it++) { | ||
auto i = it - in_begin; | ||
if (save_inputs[i]) | ||
in_ptrs.push_back(&(*it)); | ||
} | ||
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); | ||
return true; | ||
for (auto it = out_begin; it != out_end; it++) { | ||
auto i = it - out_begin; | ||
if (save_outputs[i]) | ||
in_ptrs.push_back(&(*it)); | ||
} | ||
CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs()); | ||
for (size_t i = 0; i < out_bufs.size(); i++) | ||
out_ptrs.push_back(&out_bufs[i]); | ||
CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs()); | ||
// Set is_training correct for the imperative executor. | ||
bool orig_is_train; | ||
if (ctx.is_train) | ||
orig_is_train = Imperative::Get()->set_is_training(true); | ||
else | ||
orig_is_train = Imperative::Get()->is_training(); | ||
// TODO(zhengda) CachedOp supports recording computation when running | ||
// the backward path. This is necessary if we want to support the second-order | ||
// differentiation. However, MXNet operator doesn't have an interface to | ||
// pass a flag to determine whether to record computation inside an operator. | ||
// Let's use false here for now and design a solution when the second-order | ||
// differentiation is supported. | ||
s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs); | ||
Imperative::Get()->set_is_training(orig_is_train); | ||
|
||
// Clean up what we recorded. | ||
s.forward_state.reset(); | ||
|
||
// The arrays in out_ptrs may be changed by CachedOp. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why would it be changed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for updating the comments |
||
// If it is, we need to copy data back. | ||
// For example, when the inputs and outputs share the same NDArrays, | ||
// the outputs will be replaced by inputs. | ||
// https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385 | ||
for (size_t i = 0; i < out_bufs.size(); i++) | ||
if (!out_bufs[i].IsSame(outputs[i])) | ||
CopyFromTo(out_bufs[i], outputs[i]); | ||
} | ||
|
||
/*
 * Create the operator state for _CachedOp when it runs in a symbol executor.
 * The state only holds a reference to the CachedOp; the forward state inside
 * it is produced lazily by CachedOpForward. in_shapes/in_types are unused.
 */
OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
                               Context ctx,
                               const std::vector<TShape>& in_shapes,
                               const std::vector<int>& in_types) {
  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
  return OpStatePtr::Create<CachedOpActualState>(op);
}
|
||
bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, | ||
|
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs, | |
return true; | ||
} | ||
|
||
/*
 * Attribute parser for _CachedOp. Parses the CachedOpConfig from the node's
 * attribute dictionary; if a serialized subgraph (JSON) is present, it is
 * deserialized and a CachedOp executing it is stored in attrs->parsed.
 */
void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
  CachedOpConfig param;
  try {
    param.Init(attrs->dict);
  } catch (const dmlc::ParamError& e) {
    // Rethrow with the operator name and full attribute list so the user can
    // see which node failed to parse.
    std::ostringstream os;
    os << e.what();
    os << ", in operator " << attrs->op->name << "("
       << "name=\"" << attrs->name << "\"";
    for (const auto& k : attrs->dict) {
      os << ", " << k.first << "=\"" << k.second << "\"";
    }
    os << ")";
    throw dmlc::ParamError(os.str());
  }
  if (!param.subgraph.empty()) {
    nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
    CHECK(!g.outputs.empty());
    nnvm::Symbol sym;
    sym.outputs = g.outputs;
    // Forward all raw attributes as CachedOp flags.
    std::vector<std::pair<std::string, std::string> > flags;
    flags.reserve(attrs->dict.size());
    for (const auto& kv : attrs->dict)
      flags.emplace_back(kv.first, kv.second);
    attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
  }
}
|
||
// Registration of _CachedOp as a regular stateful operator. Shape/type/storage
// inference and resource requests are delegated to the default subgraph
// helpers over the CachedOp's forward symbol.
NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
    // NOTE(review): lambda body was hidden by the diff hunk boundary;
    // reconstructed as the symmetric counterpart of num_outputs — confirm
    // against upstream cached_op.cc.
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op->num_inputs();
  })
.set_num_outputs([](const NodeAttrs& attrs) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op->num_outputs();
  })
.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
    return op->Gradient(n, ograds);
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const nnvm::NodeAttrs& attrs) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op->ListForwardInputNames();
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const nnvm::NodeAttrs& attrs) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op->ListForwardOutputNames();
  })
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
.set_attr<nnvm::FInferShape>("FInferShape",
  [](const nnvm::NodeAttrs& attrs,
     std::vector<TShape> *in_shapes,
     std::vector<TShape> *out_shapes) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
  })
.set_attr<nnvm::FInferType>("FInferType",
  [](const nnvm::NodeAttrs& attrs,
     std::vector<int> *in_types,
     std::vector<int> *out_types) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
  })
.set_attr<FInferStorageType>("FInferStorageType",
  [](const nnvm::NodeAttrs& attrs,
     const int dev_mask,
     DispatchMode* dispatch_mode,
     std::vector<int>* in_stypes,
     std::vector<int>* out_stypes) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
                                                  dev_mask, dispatch_mode,
                                                  in_stypes, out_stypes);
  })
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
  [](const nnvm::NodeAttrs& attrs) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
  })
.set_attr<FResourceRequest>("FResourceRequest",
  [](const nnvm::NodeAttrs& attrs) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
    return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
  })
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.add_argument("data", "NDArray-or-Symbol[]", "input data list");
|
||
// NOTE(review): this registration is truncated by a diff hunk boundary — the
// num_inputs lambda body, num_outputs, and the FInferStorageType lambda header
// are not visible here; only the lambda tail (delegating to
// CachedOp::BackwardStorageType) and the newly added attributes appear below.
NNVM_REGISTER_OP(_backward_CachedOp) | ||
.set_num_inputs([](const NodeAttrs& attrs){ | ||

@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp) | |
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed); | ||
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs); | ||
}) | ||
// Backward computation is stateful and shares the forward CachedOp state.
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward) | ||
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward) | ||
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType) | ||
.set_attr<bool>("TIsLayerOpBackward", true) | ||
.set_attr<bool>("TIsBackward", true); | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please elaborate on the necessity of adding this data structure in the description.