[MXNET-876] make CachedOp a normal operator (#11641)
* extend _CachedOp to be a regular operator.

* use default subgraph infer.

* fix test.

* fix compilation error.

* use default subgraph stuff.

* add comments.

* fix.

* use a more general InferStorage.

* use cachedOp as default subgraph operator.

* remove default subgraph op.

* fix.

* fix.

* rename.

* add comment.

* retrigger

* add comments.
zheng-da authored and eric-haibin-lin committed Sep 23, 2018
1 parent de01c46 commit 3caf2ca
Showing 8 changed files with 453 additions and 224 deletions.
259 changes: 225 additions & 34 deletions src/imperative/cached_op.cc
@@ -23,6 +23,7 @@
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"


namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
return op_state;
}


void CachedOp::DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
@@ -1067,34 +1067,153 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-                                  const int dev_mask,
-                                  DispatchMode* dispatch_mode,
-                                  std::vector<int> *in_attrs,
-                                  std::vector<int> *out_attrs) {
-  using namespace imperative;
-  nnvm::Graph g(fwd_graph_);
-  const auto& idx = g.indexed_graph();
-  const auto &outputs = idx.outputs();
-
-  // Prepare stypes and contexts based on inputs
-  StorageTypeVector storage_type_inputs;
-  storage_type_inputs.reserve(in_attrs->size());
-  for (size_t i = 0; i < in_attrs->size(); ++i) {
-    storage_type_inputs.emplace_back(in_attrs->at(i));
-  }
-  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
-
-  // Forward graph storage type inference
-  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-  // Retrieve result and set outputs
-  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < out_attrs->size(); i++) {
-    const auto eid = idx.entry_id(outputs[i]);
-    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
-  }
-  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-  return true;
-}
/*
 * This is the operator state of CachedOp when CachedOp is used in the symbol
 * executor. This is different from the OpState returned by CachedOp::Forward.
 * The main reason why we need this OpState is that CachedOp and the symbol executor
 * maintain OpState differently. The symbol executor generates OpState in advance
 * while CachedOp generates OpState after Forward is called. We need this data
 * structure to keep the OpState generated by CachedOp::Forward and pass it to
 * Backward.
 */
struct CachedOpActualState {
  std::shared_ptr<CachedOp> op;
  OpStatePtr forward_state;

  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
    this->op = op;
  }
};

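To make the lifecycle concrete, here is a minimal standalone sketch (illustrative names only, not MXNet code) of the pattern CachedOpActualState implements: the executor allocates the state object up front, Forward fills in the per-run state, and Backward consumes and clears it.

// Standalone sketch of the state-bridging pattern above; all names are
// illustrative stand-ins, not MXNet APIs.
#include <cassert>
#include <iostream>
#include <memory>

struct ForwardRecord {            // stands in for the OpStatePtr made by Forward
  int saved_intermediate;
};

struct ActualState {              // stands in for CachedOpActualState
  std::shared_ptr<ForwardRecord> forward_state;  // empty until Forward runs
};

int Forward(ActualState* s, int x) {
  s->forward_state = std::make_shared<ForwardRecord>();
  s->forward_state->saved_intermediate = x * x;  // pretend backward needs x^2
  return x * x * x;
}

int Backward(ActualState* s, int ograd) {
  assert(s->forward_state);       // Forward must have run first
  int igrad = 3 * s->forward_state->saved_intermediate * ograd;  // d(x^3)/dx = 3x^2
  s->forward_state.reset();       // clean up what we recorded
  return igrad;
}

int main() {
  ActualState state;              // the executor creates this in advance
  int y = Forward(&state, 2);     // Forward fills in forward_state
  int dx = Backward(&state, 1);   // Backward consumes and clears it
  std::cout << y << " " << dx << "\n";  // prints: 8 12
  return 0;
}
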
/*
* This is the forward computation when CachedOp is used as an operator in
* a symbol executor.
*/
void CachedOpForward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs(in_bufs.size());
std::vector<NDArray *> out_ptrs(out_bufs.size());
for (size_t i = 0; i < in_ptrs.size(); i++)
in_ptrs[i] = &in_bufs[i];
for (size_t i = 0; i < out_ptrs.size(); i++)
out_ptrs[i] = &out_bufs[i];

// Set is_recording correctly for the imperative executor.
bool orig_is_record;
if (ctx.need_grad)
orig_is_record = Imperative::Get()->set_is_recording(true);
else
orig_is_record = Imperative::Get()->is_recording();
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);
Imperative::Get()->set_is_recording(orig_is_record);
// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
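
One thing worth noting: the manual save/restore of the is_recording/is_training flags above would leave them modified if Forward threw. A scope-guard version of the same save/restore pattern, as a standalone sketch with illustrative names (FlagOwner stands in for the Imperative singleton):

// Standalone sketch of an RAII guard for the set/restore-flag pattern used in
// CachedOpForward; FlagOwner and TrainingGuard are illustrative, not MXNet APIs.
#include <iostream>

class FlagOwner {                  // stands in for Imperative::Get()
 public:
  bool set_is_training(bool v) { bool old = is_training_; is_training_ = v; return old; }
  bool is_training() const { return is_training_; }
 private:
  bool is_training_ = false;
};

class TrainingGuard {              // restores the old flag even on early exit
 public:
  TrainingGuard(FlagOwner* owner, bool value)
      : owner_(owner), old_(owner->set_is_training(value)) {}
  ~TrainingGuard() { owner_->set_is_training(old_); }
 private:
  FlagOwner* owner_;
  bool old_;
};

int main() {
  FlagOwner imp;
  {
    TrainingGuard g(&imp, true);             // like set_is_training(true)
    std::cout << imp.is_training() << "\n";  // prints 1 while the guard lives
  }
  std::cout << imp.is_training() << "\n";    // prints 0: restored on scope exit
  return 0;
}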

/*
* This is the backward computation when CachedOp is used as an operator in
* a symbol executor.
*/
void CachedOpBackward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace nnvm;
using namespace imperative;
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs;
std::vector<NDArray *> out_ptrs;
CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
in_ptrs.reserve(s.op->num_backward_inputs());
out_ptrs.reserve(s.op->num_inputs());

const std::vector<bool> &save_inputs = s.op->save_inputs();
const std::vector<bool> &save_outputs = s.op->save_outputs();
size_t bwd_in_dep = s.op->num_inputs();
size_t bwd_out_dep = s.op->num_outputs();
CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;

// Find inputs, outputs and ograds
auto ograds_begin = in_bufs.begin();
auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
auto in_begin = ograds_end;
auto in_end = in_begin + bwd_in_dep;
auto out_begin = in_end;
auto out_end = in_bufs.end();

for (auto it = ograds_begin; it != ograds_end; it++)
in_ptrs.push_back(&(*it));

CHECK_EQ(save_inputs.size(), in_end - in_begin);
CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
for (auto it = in_begin; it != in_end; it++) {
auto i = it - in_begin;
if (save_inputs[i])
in_ptrs.push_back(&(*it));
}
for (auto it = out_begin; it != out_end; it++) {
auto i = it - out_begin;
if (save_outputs[i])
in_ptrs.push_back(&(*it));
}
CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
for (size_t i = 0; i < out_bufs.size(); i++)
out_ptrs.push_back(&out_bufs[i]);
CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
// TODO(zhengda) CachedOp supports recording computation when running
// the backward path, which is necessary if we want to support second-order
// differentiation. However, the MXNet operator interface doesn't provide a way
// to pass a flag into an operator to determine whether to record computation.
// Let's use false here for now and design a solution when second-order
// differentiation is supported.
s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);

// Clean up what we recorded.
s.forward_state.reset();

// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
// For example, when the inputs and outputs share the same NDArrays,
// the outputs will be replaced by the inputs.
// https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
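
The pointer gathering above relies on a fixed layout of the backward inputs: [output gradients | saved forward inputs | saved forward outputs]. A standalone sketch of the index arithmetic, assuming a subgraph with two forward inputs and one forward output, all of which the backward pass depends on:

// Standalone illustration of the backward-input layout used by CachedOpBackward:
// inputs = [ograds..., saved forward inputs..., saved forward outputs...].
#include <cstddef>
#include <iostream>

int main() {
  std::size_t bwd_in_dep = 2;            // s.op->num_inputs()
  std::size_t bwd_out_dep = 1;           // s.op->num_outputs()
  std::size_t num_backward_inputs = 4;   // s.op->num_backward_inputs()
  std::size_t bwd_ograd_dep = num_backward_inputs - bwd_in_dep - bwd_out_dep;  // 1

  // Offsets of each section in the flat `inputs` vector.
  std::size_t ograds_begin = 0;
  std::size_t in_begin = ograds_begin + bwd_ograd_dep;  // 1
  std::size_t out_begin = in_begin + bwd_in_dep;        // 3
  std::cout << "ograds in [" << ograds_begin << ", " << in_begin << "), "
            << "inputs in [" << in_begin << ", " << out_begin << "), "
            << "outputs in [" << out_begin << ", " << num_backward_inputs << ")\n";
  return 0;
}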

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return OpStatePtr::Create<CachedOpActualState>(op);
}

bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
CachedOpConfig param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
if (!param.subgraph.empty()) {
nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
CHECK(!g.outputs.empty());
nnvm::Symbol sym;
sym.outputs = g.outputs;
std::vector<std::pair<std::string, std::string> > flags;
for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
flags.emplace_back(it->first, it->second);
attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
}
}
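
For context, the subgraph attribute consumed here is a graph serialized with nnvm's JSON pass, so the LoadJSON above is simply the inverse of SaveJSON on the producing side. A hedged sketch of that producing side, assuming the MXNet/nnvm headers are available; MakeCachedOpAttrs is a hypothetical helper, and error handling plus the construction of the symbol itself are elided:

// Hedged sketch: building the attribute dictionary that CachedOpParamParser
// consumes. MakeCachedOpAttrs is hypothetical, not part of MXNet.
#include <string>
#include <unordered_map>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <nnvm/symbolic.h>

std::unordered_map<std::string, std::string>
MakeCachedOpAttrs(const nnvm::Symbol& sym) {
  nnvm::Graph g;
  g.outputs = sym.outputs;                     // wrap the symbol as a graph
  std::unordered_map<std::string, std::string> dict;
  dict["subgraph"] = nnvm::pass::SaveJSON(g);  // inverse of the LoadJSON call above
  dict["static_alloc"] = "true";               // any other CachedOpConfig field works too
  return dict;                                 // becomes attrs->dict of a _CachedOp node
}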

NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
-.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
-                                                     const int dev_mask,
-                                                     DispatchMode* dispatch_mode,
-                                                     std::vector<int> *in_attrs,
-                                                     std::vector<int> *out_attrs) {
-  const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
-  return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
-})
.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
    return op->Gradient(n, ograds);
-  });
  })
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardInputNames();
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardOutputNames();
})
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
})
.set_attr<FInferStorageType>("FInferStorageType",
[](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
dev_mask, dispatch_mode,
in_stypes, out_stypes);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
})
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.add_argument("data", "NDArray-or-Symbol[]", "input data list");
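
All of the inference attributes registered above delegate to the DefaultSubgraphOp*Helper utilities from src/operator/subgraph/common.h, which run the standard inference passes over the wrapped forward graph. A simplified sketch of the idea for shape inference follows; it is an approximation for illustration, not the real helper, which also seeds reverse inference and does proper error reporting:

// Simplified sketch of what a DefaultSubgraphOpShapeHelper-style function does:
// run ordinary shape inference on the subgraph and copy the results out.
#include <utility>
#include <vector>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <nnvm/symbolic.h>
#include <nnvm/tuple.h>

bool SubgraphShapeSketch(const nnvm::Symbol& fwd_sym,
                         std::vector<nnvm::TShape>* in_shapes,
                         std::vector<nnvm::TShape>* out_shapes) {
  nnvm::Graph g;
  g.outputs = fwd_sym.outputs;
  // Seed the known input shapes, then run the registered InferShape pass.
  g = nnvm::pass::InferShape(std::move(g), *in_shapes);
  const auto& idx = g.indexed_graph();
  const auto& shapes = g.GetAttr<std::vector<nnvm::TShape>>("shape");
  for (size_t i = 0; i < out_shapes->size(); ++i)
    (*out_shapes)[i] = shapes[idx.entry_id(idx.outputs()[i])];
  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}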

NNVM_REGISTER_OP(_backward_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

28 changes: 21 additions & 7 deletions src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
bool static_shape;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
std::string subgraph;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
.set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of parameters.");
DMLC_DECLARE_FIELD(subgraph)
.set_default(std::string(""))
.describe("JSON string of a subgraph.");
}
};

@@ -80,6 +84,10 @@ class CachedOp {
uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
uint32_t num_backward_outputs() const {
auto &idx = fwd_graph_.indexed_graph();
return idx.input_nodes().size() - idx.mutable_input_nodes().size();
}
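  // Worked example (names hypothetical): for a BatchNorm-style subgraph whose
  // inputs are {data, gamma, beta, moving_mean, moving_var}, the two moving
  // statistics are mutable auxiliary states and receive no gradient, so
  // input_nodes().size() == 5, mutable_input_nodes().size() == 2, and
  // num_backward_outputs() == 5 - 2 == 3 (gradients for data, gamma, beta).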
std::vector<bool>& save_inputs() {
return save_inputs_;
}
@@ -102,20 +110,26 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
-  // forward storage type inference
-  bool ForwardStorageType(
-      const nnvm::NodeAttrs& attrs,
-      const int dev_mask,
-      DispatchMode* dispatch_mode,
-      std::vector<int> *in_attrs,
-      std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
std::vector<std::string> ListForwardInputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListInputNames(nnvm::Symbol::kAll);
}
std::vector<std::string> ListForwardOutputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListOutputNames();
}
nnvm::Symbol GetForwardSym() const {
nnvm::Symbol sym;
sym.outputs = fwd_graph_.outputs;
return sym;
}
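  // For example (illustrative): if the wrapped subgraph is a single
  // FullyConnected layer, ListForwardInputNames() would return
  // {"data", "weight", "bias"} and ListForwardOutputNames() something like
  // {"fullyconnected0_output"}, matching the names in the original symbol.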

private:
struct GraphInfo;