This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-876] make CachedOp a normal operator #11641

Merged
16 commits merged on Sep 23, 2018
src/imperative/cached_op.cc (259 changes: 225 additions & 34 deletions)
@@ -23,6 +23,7 @@
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"


namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
return op_state;
}


void CachedOp::DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
@@ -1067,34 +1067,153 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

// Removed in this PR: CachedOp::ForwardStorageType. Storage type inference
// for the forward graph now goes through the generic subgraph helper
// registered on _CachedOp below.
bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
using namespace imperative;
nnvm::Graph g(fwd_graph_);
const auto& idx = g.indexed_graph();
const auto &outputs = idx.outputs();

// Prepare stypes and contexts based on inputs
StorageTypeVector storage_type_inputs;
storage_type_inputs.reserve(in_attrs->size());
for (size_t i = 0; i < in_attrs->size(); ++i) {
storage_type_inputs.emplace_back(in_attrs->at(i));
}
exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);

// Forward graph storage type inference
CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
// Retrieve result and set outputs
const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < out_attrs->size(); i++) {
const auto eid = idx.entry_id(outputs[i]);
STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
}
DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
return true;
}

/*
 * This is the operator state of CachedOp when CachedOp is used in the symbol
 * executor. It is different from the OpStatePtr returned by CachedOp::Forward:
 * the symbol executor creates its operator states in advance, while CachedOp
 * creates its state only when Forward is called. This data structure keeps the
 * state generated by CachedOp::Forward so it can be passed to Backward.
 */
struct CachedOpActualState {
std::shared_ptr<CachedOp> op;
OpStatePtr forward_state;

explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
this->op = op;
}
};

Review comment (Contributor): Please elaborate on the necessity of adding this data structure in the description.
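To make the lifecycle concrete, here is a minimal standalone sketch of the same pattern. FakeCachedOp, FakeForwardRecord, and ActualState are invented stand-ins for CachedOp, its forward state, and CachedOpActualState; this is not MXNet code, just the shape of the state handoff:

#include <cassert>
#include <memory>

struct FakeForwardRecord { int id; };
using FakeOpStatePtr = std::shared_ptr<FakeForwardRecord>;

struct FakeCachedOp {
  // Stands in for CachedOp::Forward, which creates its own state lazily.
  FakeOpStatePtr Forward() { return std::make_shared<FakeForwardRecord>(FakeForwardRecord{42}); }
  // Stands in for CachedOp::Backward, which needs that state back.
  void Backward(const FakeOpStatePtr& fwd) { assert(fwd && fwd->id == 42); }
};

// Mirrors CachedOpActualState: created eagerly by the executor,
// filled in only when Forward actually runs.
struct ActualState {
  std::shared_ptr<FakeCachedOp> op;
  FakeOpStatePtr forward_state;  // empty until Forward runs
  explicit ActualState(std::shared_ptr<FakeCachedOp> o) : op(std::move(o)) {}
};

int main() {
  ActualState s(std::make_shared<FakeCachedOp>());  // bind time: no data yet
  assert(!s.forward_state);
  s.forward_state = s.op->Forward();  // forward pass stashes the state
  s.op->Backward(s.forward_state);    // backward pass consumes it
  s.forward_state.reset();            // released afterwards, as in CachedOpBackward
  return 0;
}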
/*
* This is the forward computation when CachedOp is used as an operator in
* a symbol executor.
*/
void CachedOpForward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs(in_bufs.size());
std::vector<NDArray *> out_ptrs(out_bufs.size());
for (size_t i = 0; i < in_ptrs.size(); i++)
in_ptrs[i] = &in_bufs[i];
for (size_t i = 0; i < out_ptrs.size(); i++)
out_ptrs[i] = &out_bufs[i];

// Set is_recording correctly for the imperative executor.
bool orig_is_record;
if (ctx.need_grad)
orig_is_record = Imperative::Get()->set_is_recording(true);
else
orig_is_record = Imperative::Get()->is_recording();
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);
Imperative::Get()->set_is_recording(orig_is_record);
// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
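The bookkeeping around is_recording and is_training is a save-set-restore pattern: flip the thread-global flag only when needed, and always put the previous value back. A minimal standalone sketch of that pattern (set_is_recording here is a toy global, not the MXNet API; the real setter likewise returns the previous value):

#include <cassert>

static bool g_is_recording = false;

static bool set_is_recording(bool v) {
  bool old = g_is_recording;
  g_is_recording = v;
  return old;  // returning the old value lets callers restore it
}

static void run_forward(bool need_grad) {
  // Flip the flag only when gradients are needed; otherwise just read it.
  bool orig = need_grad ? set_is_recording(true) : g_is_recording;
  assert(!need_grad || g_is_recording);  // the imperative Forward would run here
  set_is_recording(orig);  // restore so the caller's mode is untouched
}

int main() {
  run_forward(true);
  assert(!g_is_recording);  // back to the original value
  run_forward(false);
  assert(!g_is_recording);
  return 0;
}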

/*
* This is the backward computation when CachedOp is used as an operator in
* a symbol executor.
*/
void CachedOpBackward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace nnvm;
using namespace imperative;
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs;
std::vector<NDArray *> out_ptrs;
CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
in_ptrs.reserve(s.op->num_backward_inputs());
out_ptrs.reserve(s.op->num_inputs());

const std::vector<bool> &save_inputs = s.op->save_inputs();
const std::vector<bool> &save_outputs = s.op->save_outputs();
size_t bwd_in_dep = s.op->num_inputs();
size_t bwd_out_dep = s.op->num_outputs();
CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;

// Find the ograds, inputs, and outputs in the flat input buffer
auto ograds_begin = in_bufs.begin();
auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
auto in_begin = ograds_end;
auto in_end = in_begin + bwd_in_dep;
auto out_begin = in_end;
auto out_end = in_bufs.end();

for (auto it = ograds_begin; it != ograds_end; it++)
in_ptrs.push_back(&(*it));

CHECK_EQ(save_inputs.size(), in_end - in_begin);
CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
for (auto it = in_begin; it != in_end; it++) {
auto i = it - in_begin;
if (save_inputs[i])
in_ptrs.push_back(&(*it));
}
for (auto it = out_begin; it != out_end; it++) {
auto i = it - out_begin;
if (save_outputs[i])
in_ptrs.push_back(&(*it));
}
CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
for (size_t i = 0; i < out_bufs.size(); i++)
out_ptrs.push_back(&out_bufs[i]);
CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
// TODO(zhengda) CachedOp supports recording computation when running
// the backward path. This is necessary if we want to support second-order
// differentiation. However, MXNet operators don't have an interface for
// passing a flag that determines whether to record computation inside an
// operator. Let's use false here for now and design a solution when
// second-order differentiation is supported.
s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);

// Clean up what we recorded.
s.forward_state.reset();

// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
// For example, when the inputs and outputs share the same NDArrays,
// the outputs will be replaced by inputs.
// https://github.com/apache/incubator-mxnet/blob/v1.2.0/src/imperative/cached_op.cc#L385

Review comment (Member): why would it be changed?
Review comment (Member): Thanks for updating the comments
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
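The index arithmetic above assumes the backward node receives one flat buffer laid out as [ograds | forward inputs | forward outputs], from which only the saved entries are handed to CachedOp::Backward. A standalone sketch with made-up sizes (3 inputs, 1 output, 1 output gradient; the int values stand in for NDArrays):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::size_t bwd_in_dep = 3, bwd_out_dep = 1, num_backward_inputs = 5;
  std::size_t bwd_ograd_dep = num_backward_inputs - bwd_in_dep - bwd_out_dep;  // = 1

  // Flat buffer: [ograd | in0 in1 in2 | out0]
  std::vector<int> bufs = {100, 0, 1, 2, 50};
  std::vector<bool> save_inputs = {true, false, true};  // in1 not needed for grad
  std::vector<bool> save_outputs = {true};

  std::vector<const int*> picked;
  for (std::size_t i = 0; i < bwd_ograd_dep; ++i)  // all output gradients
    picked.push_back(&bufs[i]);
  for (std::size_t i = 0; i < bwd_in_dep; ++i)     // only the saved inputs
    if (save_inputs[i]) picked.push_back(&bufs[bwd_ograd_dep + i]);
  for (std::size_t i = 0; i < bwd_out_dep; ++i)    // only the saved outputs
    if (save_outputs[i]) picked.push_back(&bufs[bwd_ograd_dep + bwd_in_dep + i]);

  assert(picked.size() == 4);
  assert(*picked[0] == 100 && *picked[3] == 50);
  return 0;
}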

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return OpStatePtr::Create<CachedOpActualState>(op);
}

bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1262,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
CachedOpConfig param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
if (!param.subgraph.empty()) {
nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
CHECK(!g.outputs.empty());
nnvm::Symbol sym;
sym.outputs = g.outputs;
std::vector<std::pair<std::string, std::string> > flags;
for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
flags.emplace_back(it->first, it->second);
attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
}
}
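The error handling above follows a common enrichment pattern: catch the low-level parse error, rebuild the message with the operator name and the node's full attribute dict, and rethrow. A minimal standalone sketch, with std::invalid_argument standing in for dmlc::ParamError and a deliberately simulated failure:

#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

void parse_attrs(const std::string& op_name, const std::string& node_name,
                 const std::map<std::string, std::string>& dict) {
  try {
    throw std::invalid_argument("subgraph is not valid JSON");  // simulated failure
  } catch (const std::exception& e) {
    std::ostringstream os;
    os << e.what() << ", in operator " << op_name << "(name=\"" << node_name << "\"";
    for (const auto& kv : dict)
      os << ", " << kv.first << "=\"" << kv.second << "\"";
    os << ")";
    throw std::invalid_argument(os.str());  // rethrow with full context
  }
}

int main() {
  try {
    parse_attrs("_CachedOp", "cached0", {{"static_alloc", "1"}});
  } catch (const std::exception&) {
    // The message now reads:
    // subgraph is not valid JSON, in operator _CachedOp(name="cached0", static_alloc="1")
  }
  return 0;
}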

NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1298,62 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
// Removed in this PR: the old FInferStorageType registration, which called
// CachedOp::ForwardStorageType (itself removed above):
//
//   .set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
//       const int dev_mask,
//       DispatchMode* dispatch_mode,
//       std::vector<int> *in_attrs,
//       std::vector<int> *out_attrs) {
//     const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
//     return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
//   })
.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
return op->Gradient(n, ograds);
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardInputNames();
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardOutputNames();
})
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
})
.set_attr<FInferStorageType>("FInferStorageType",
[](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
dev_mask, dispatch_mode,
in_stypes, out_stypes);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
})
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.add_argument("data", "NDArray-or-Symbol[]", "input data list");
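For readers unfamiliar with NNVM registration: the chained set_attr<> calls store type-erased attributes in a per-operator registry keyed by name, and graph passes look them up later. A toy sketch of that mechanism (OpReg is invented and far simpler than NNVM's actual registry; requires C++17 for std::any):

#include <any>
#include <cassert>
#include <functional>
#include <map>
#include <string>

struct OpReg {
  std::map<std::string, std::any> attrs;

  template <typename T>
  OpReg& set_attr(const std::string& name, T value) {
    attrs[name] = std::move(value);
    return *this;  // returning *this is what enables the fluent chain
  }

  template <typename T>
  T get_attr(const std::string& name) const {
    return std::any_cast<T>(attrs.at(name));  // caller must know the type
  }
};

int main() {
  using NumOutputs = std::function<int()>;
  OpReg op;
  op.set_attr("num_outputs", NumOutputs([] { return 2; }))
    .set_attr("exec_type", std::string("subgraph"));
  // A later pass retrieves an attribute by name and invokes it.
  assert(op.get_attr<NumOutputs>("num_outputs")() == 2);
  assert(op.get_attr<std::string>("exec_type") == "subgraph");
  return 0;
}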

NNVM_REGISTER_OP(_backward_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1372,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

src/imperative/cached_op.h (28 changes: 21 additions & 7 deletions)
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
bool static_shape;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
std::string subgraph;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
.set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of parameters.");
DMLC_DECLARE_FIELD(subgraph)
.set_default(std::string(""))
.describe("JSON string of a subgraph.");
}
};
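The new subgraph field rides along the usual dmlc::Parameter flow: every field has a default, and Init overrides only the keys present in the node's attribute dict, so existing CachedOp users are unaffected. A toy sketch of that behavior (ToyConfig is invented; dmlc::Parameter additionally handles typed fields, range checks, and error reporting):

#include <cassert>
#include <map>
#include <string>

struct ToyConfig {
  bool static_alloc = false;   // default, as in DMLC_DECLARE_FIELD(...).set_default(false)
  std::string subgraph;        // default "": no subgraph attached

  void Init(const std::map<std::string, std::string>& dict) {
    auto it = dict.find("static_alloc");
    if (it != dict.end()) static_alloc = (it->second == "1" || it->second == "true");
    it = dict.find("subgraph");
    if (it != dict.end()) subgraph = it->second;
  }
};

int main() {
  ToyConfig cfg;
  cfg.Init({{"subgraph", "{\"nodes\": []}"}});
  assert(!cfg.static_alloc);      // untouched default
  assert(!cfg.subgraph.empty());  // the JSON string was picked up
  return 0;
}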

@@ -80,6 +84,10 @@ class CachedOp {
uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
uint32_t num_backward_outputs() const {
auto &idx = fwd_graph_.indexed_graph();
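// Mutable inputs (e.g. aux states such as BatchNorm's running statistics)
// receive no gradient, so they are excluded from the count below.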
return idx.input_nodes().size() - idx.mutable_input_nodes().size();
}
std::vector<bool>& save_inputs() {
return save_inputs_;
}
@@ -102,20 +110,26 @@
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
// Removed in this PR: the forward storage type inference declaration,
// superseded by the generic subgraph storage type helper:
//
//   bool ForwardStorageType(
//       const nnvm::NodeAttrs& attrs,
//       const int dev_mask,
//       DispatchMode* dispatch_mode,
//       std::vector<int> *in_attrs,
//       std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
std::vector<std::string> ListForwardInputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListInputNames(nnvm::Symbol::kAll);
}
std::vector<std::string> ListForwardOutputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListOutputNames();
}
nnvm::Symbol GetForwardSym() const {
nnvm::Symbol sym;
sym.outputs = fwd_graph_.outputs;
return sym;
}

private:
struct GraphInfo;