This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-876] make CachedOp a normal operator #11641

Merged (16 commits) on Sep 23, 2018
251 changes: 217 additions & 34 deletions src/imperative/cached_op.cc
@@ -23,6 +23,7 @@
#include "../executor/exec_pass.h"
#include "../profiler/profiler.h"
#include "../operator/operator_common.h"
#include "../operator/subgraph/common.h"


namespace mxnet {
@@ -874,7 +875,6 @@ OpStatePtr CachedOp::Forward(
return op_state;
}


void CachedOp::DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
@@ -1067,34 +1067,145 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-const int dev_mask,
-DispatchMode* dispatch_mode,
-std::vector<int> *in_attrs,
-std::vector<int> *out_attrs) {
-using namespace imperative;
-nnvm::Graph g(fwd_graph_);
-const auto& idx = g.indexed_graph();
-const auto &outputs = idx.outputs();
-
-// Prepare stypes and contexts based on inputs
-StorageTypeVector storage_type_inputs;
-storage_type_inputs.reserve(in_attrs->size());
-for (size_t i = 0; i < in_attrs->size(); ++i) {
-storage_type_inputs.emplace_back(in_attrs->at(i));
-}
-exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
-
-// Forward graph storage type inference
-CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-// Retrieve result and set outputs
-const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-for (size_t i = 0; i < out_attrs->size(); i++) {
-const auto eid = idx.entry_id(outputs[i]);
-STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
-}
-DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-return true;
-}

/*
* This is the operator state of CachedOp when CachedOp is used in the symbol
* executor. It is different from the OpState returned by CachedOp::Forward.
* The main reason we need this OpState is that CachedOp and the symbol executor
* maintain OpState differently: the symbol executor generates OpState in
* advance, while CachedOp generates OpState only after Forward is called. We
* need this data structure to keep the OpState generated by CachedOp::Forward
* and pass it to Backward.
*/

Contributor:
Please elaborate on the necessity of adding this data structure in the description.

Contributor:
Missing class doc

struct CachedOpActualState {
std::shared_ptr<CachedOp> op;
OpStatePtr forward_state;

explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
this->op = op;
}
};
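To make the handoff concrete, here is a minimal, self-contained sketch (toy types and names, not the real MXNet classes) of how the executor-owned state is created up front, filled in by the forward pass, and consumed by the backward pass:

#include <cassert>
#include <memory>

// Toy stand-ins for CachedOp / OpStatePtr; illustration only.
struct ForwardRecord {};  // what CachedOp::Forward would record

struct ActualState {
  std::shared_ptr<ForwardRecord> forward_state;  // filled by forward
};

void ToyForward(ActualState* s) {
  // The executor created `s` ahead of time; forward fills in the record
  // that backward will need, mirroring s.forward_state = s.op->Forward(...).
  s->forward_state = std::make_shared<ForwardRecord>();
}

void ToyBackward(ActualState* s) {
  assert(s->forward_state != nullptr);  // forward must have run first
  // ... compute gradients using the recorded state ...
  s->forward_state.reset();             // release it, as CachedOpBackward does
}

int main() {
  ActualState state;   // created in advance, like FCreateOpState
  ToyForward(&state);
  ToyBackward(&state);
  return 0;
}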

/*
* This is the forward computation when CachedOp is used as an operator in
* a symbol executor.
*/

Contributor:
Missing a short documentation stating what's the intention and how it works

void CachedOpForward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs(in_bufs.size());
std::vector<NDArray *> out_ptrs(out_bufs.size());
for (size_t i = 0; i < in_ptrs.size(); i++)
in_ptrs[i] = &in_bufs[i];
for (size_t i = 0; i < out_ptrs.size(); i++)
out_ptrs[i] = &out_bufs[i];

// Set is_recording correctly for the imperative executor.
bool orig_is_record;
if (ctx.need_grad)
orig_is_record = Imperative::Get()->set_is_recording(true);
else
orig_is_record = Imperative::Get()->is_recording();
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);
Imperative::Get()->set_is_recording(orig_is_record);
// The arrays in out_ptrs may be changed by CachedOp.
// If so, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
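CachedOpForward has to run the imperative executor under the symbol executor's context flags and then restore whatever was set before. A self-contained sketch of that save/force/restore pattern (toy global flag and hypothetical names, not the Imperative API):

#include <cassert>

// Toy stand-in for Imperative::Get()->is_training()/set_is_training().
static bool g_is_train = false;

bool set_is_train(bool v) {
  bool old = g_is_train;
  g_is_train = v;
  return old;
}

void RunWithTrainFlag(bool ctx_is_train) {
  // Same shape as the code above: force the flag on when the context asks
  // for it, otherwise just remember the current value; restore afterwards.
  bool orig = ctx_is_train ? set_is_train(true) : g_is_train;
  // ... run the imperative computation that reads g_is_train ...
  set_is_train(orig);
}

int main() {
  RunWithTrainFlag(true);
  assert(g_is_train == false);  // the original value was restored
  return 0;
}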

/*
* This is the backward computation when CachedOp is used as an operator in
* a symbol executor.
*/
void CachedOpBackward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace nnvm;
using namespace imperative;
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs;
std::vector<NDArray *> out_ptrs;
CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
in_ptrs.reserve(s.op->num_backward_inputs());
out_ptrs.reserve(s.op->num_inputs());

const std::vector<bool> &save_inputs = s.op->save_inputs();
const std::vector<bool> &save_outputs = s.op->save_outputs();
size_t bwd_in_dep = s.op->num_inputs();
size_t bwd_out_dep = s.op->num_outputs();
CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;

// Find inputs, outputs and ograds
auto ograds_begin = in_bufs.begin();
auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
auto in_begin = ograds_end;
auto in_end = in_begin + bwd_in_dep;
auto out_begin = in_end;
auto out_end = in_bufs.end();

for (auto it = ograds_begin; it != ograds_end; it++)
in_ptrs.push_back(&(*it));

CHECK_EQ(save_inputs.size(), in_end - in_begin);
CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
for (auto it = in_begin; it != in_end; it++) {
auto i = it - in_begin;
if (save_inputs[i])
in_ptrs.push_back(&(*it));
}

Contributor:
++it is potentially faster

Contributor Author:
really? Where is this documented?

Contributor (@larroy, Jul 11, 2018):
@zheng-da it's a well-known thing for old C++ farts. It's in reference C++ books like http://www.cppstdlib.com/ or Stroustrup, e.g. https://stackoverflow.com/questions/1077026/incrementing-iterators-it-more-efficient-than-it. In most cases it probably doesn't make a difference, especially for simple iterators where the iterator is just a pointer. That's why I said it is potentially faster. It's good idiomatic practice to always use preincrement.
for (auto it = out_begin; it != out_end; it++) {
auto i = it - out_begin;
if (save_outputs[i])
in_ptrs.push_back(&(*it));
}
CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
for (size_t i = 0; i < out_bufs.size(); i++)
out_ptrs.push_back(&out_bufs[i]);
CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
// TODO(zhengda) is it right to use false here?
s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Member:
please add more comment on retain_graph=False
Imperative::Get()->set_is_training(orig_is_train);

// Clean up what we recorded.
s.forward_state.reset();

// The arrays in out_ptrs may be changed by CachedOp.
Member:
why would it be changed?

Member:
Thanks for updating the comments
// If so, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}
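The backward inputs arrive as one flat vector laid out as [ograds | forward inputs | forward outputs], and only the entries flagged by save_inputs/save_outputs are forwarded to CachedOp::Backward. A self-contained sketch of the same slicing, with hypothetical sizes:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical counts: 2 output gradients, 3 forward inputs, 1 forward output.
  const size_t n_ograd = 2, n_in = 3, n_out = 1;
  std::vector<int> bwd_inputs(n_ograd + n_in + n_out);

  // Slice the flat vector exactly as CachedOpBackward does.
  auto ograds_begin = bwd_inputs.begin();
  auto ograds_end   = ograds_begin + n_ograd;
  auto in_begin     = ograds_end;
  auto in_end       = in_begin + n_in;
  auto out_begin    = in_end;
  auto out_end      = bwd_inputs.end();

  // Only the flagged inputs/outputs are needed by the backward graph.
  const std::vector<bool> save_inputs  = {true, false, true};
  const std::vector<bool> save_outputs = {true};

  std::vector<int*> ptrs;
  for (auto it = ograds_begin; it != ograds_end; ++it)
    ptrs.push_back(&*it);
  for (auto it = in_begin; it != in_end; ++it)
    if (save_inputs[it - in_begin]) ptrs.push_back(&*it);
  for (auto it = out_begin; it != out_end; ++it)
    if (save_outputs[it - out_begin]) ptrs.push_back(&*it);

  assert(ptrs.size() == 2 + 2 + 1);  // ograds + saved inputs + saved outputs
  return 0;
}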

OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return OpStatePtr::Create<CachedOpActualState>(op);
}

bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
@@ -1143,6 +1254,32 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
CachedOpConfig param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
if (!param.subgraph.empty()) {
nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
CHECK(!g.outputs.empty());
nnvm::Symbol sym;
sym.outputs = g.outputs;
std::vector<std::pair<std::string, std::string> > flags;
for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
flags.emplace_back(it->first, it->second);
attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
}
}
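The try/catch here follows a common parameter-parsing idiom: catch the low-level ParamError and re-throw it with the operator name and the full attribute dictionary appended, so the user can see which node failed. A self-contained sketch of that idiom (standard exceptions standing in for dmlc::ParamError):

#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

void ParseOrExplain(const std::string& op_name,
                    const std::map<std::string, std::string>& dict) {
  try {
    // Stand-in for param.Init(attrs->dict) failing on a bad value.
    throw std::runtime_error("Invalid Parameter value");
  } catch (const std::exception& e) {
    std::ostringstream os;
    os << e.what() << ", in operator " << op_name << "(";
    for (const auto& kv : dict)
      os << kv.first << "=\"" << kv.second << "\" ";
    os << ")";
    throw std::runtime_error(os.str());  // re-throw with full context
  }
}

int main() {
  try {
    ParseOrExplain("_CachedOp", {{"static_alloc", "maybe"}});
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";  // names the op and the offending attrs
  }
  return 0;
}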

NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1153,19 +1290,62 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
return op->Gradient(n, ograds);
});
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardInputNames();
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardOutputNames();
})
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpShapeHelper(op->GetForwardSym(), in_shapes, out_shapes);
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpTypeHelper(op->GetForwardSym(), in_types, out_types);
})
.set_attr<FInferStorageType>("FInferStorageType",
[](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpStorageTypeHelper(op->GetForwardSym(),
dev_mask, dispatch_mode,
in_stypes, out_stypes);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpMutableInputsHelper(op->GetForwardSym());
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op::DefaultSubgraphOpResourceRequestHelper(op->GetForwardSym());
})
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.add_argument("data", "NDArray-or-Symbol[]", "input data list");
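Note the shape of every registration above: the lambda unwraps the CachedOp from attrs.parsed and delegates the query to the stored subgraph. A self-contained toy of that delegation pattern (toy registry and types, not the NNVM API):

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy model of the registration pattern: every graph-level query on the
// operator is answered by delegating to the embedded "subgraph" object.
struct ToySubgraph {
  std::vector<int> shapes() const { return {2, 3}; }
};

struct ToyAttrs { std::shared_ptr<ToySubgraph> parsed; };

using ShapeFn = std::function<std::vector<int>(const ToyAttrs&)>;

int main() {
  std::map<std::string, ShapeFn> registry;
  // Mirrors .set_attr<FInferShape>(...): unwrap attrs.parsed, delegate.
  registry["_CachedOp"] = [](const ToyAttrs& attrs) {
    return attrs.parsed->shapes();
  };
  ToyAttrs attrs{std::make_shared<ToySubgraph>()};
  auto s = registry["_CachedOp"](attrs);  // yields {2, 3}
  return s.empty() ? 1 : 0;
}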

NNVM_REGISTER_OP(_backward_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs){
@@ -1184,6 +1364,9 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
.set_attr<FExecType>("FExecType", op::DefaultSubgraphOpExecType)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

28 changes: 21 additions & 7 deletions src/imperative/cached_op.h
@@ -37,6 +37,7 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
bool static_shape;
nnvm::Tuple<uint32_t> data_indices;
nnvm::Tuple<uint32_t> param_indices;
std::string subgraph;
DMLC_DECLARE_PARAMETER(CachedOpConfig) {
DMLC_DECLARE_FIELD(static_alloc)
.set_default(false)
@@ -62,6 +63,9 @@ struct CachedOpConfig : public dmlc::Parameter<CachedOpConfig> {
DMLC_DECLARE_FIELD(param_indices)
.set_default(nnvm::Tuple<uint32_t>())
.describe("Position of parameters.");
DMLC_DECLARE_FIELD(subgraph)
.set_default(std::string(""))
.describe("JSON string of a subgraph.");
}
};
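With the new field, a caller that wants the symbol executor to instantiate a CachedOp only has to serialize the graph to JSON and pass it under the subgraph key; every other entry in the dictionary is forwarded to the CachedOp constructor as a flag. A hedged sketch (placeholder JSON; in MXNet the string would come from nnvm::pass::SaveJSON):

#include <string>
#include <utility>
#include <vector>

int main() {
  // Placeholder for a JSON-serialized nnvm graph.
  const std::string graph_json = "{\"nodes\": []}";

  // These key/value flags mirror attrs->dict in CachedOpParamParser;
  // "subgraph" is the field added by this PR.
  std::vector<std::pair<std::string, std::string>> flags = {
    {"static_alloc", "true"},
    {"subgraph", graph_json},
  };
  // In cached_op.cc this dictionary ends up in
  // CachedOpPtr(new CachedOp(sym, flags)).
  return flags.size() == 2 ? 0 : 1;
}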

@@ -80,6 +84,10 @@ class CachedOp {
uint32_t num_backward_inputs() const {
return bwd_ograd_dep_.size() + bwd_in_dep_.size() + bwd_out_dep_.size();
}
uint32_t num_backward_outputs() const {
auto &idx = fwd_graph_.indexed_graph();
return idx.input_nodes().size() - idx.mutable_input_nodes().size();
}
std::vector<bool>& save_inputs() {
return save_inputs_;
}
@@ -102,20 +110,26 @@ class CachedOp {
const std::vector<NDArray*>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<NDArray*>& outputs);
-// forward storage type inference
-bool ForwardStorageType(
-const nnvm::NodeAttrs& attrs,
-const int dev_mask,
-DispatchMode* dispatch_mode,
-std::vector<int> *in_attrs,
-std::vector<int> *out_attrs);
// backward storage type inference
bool BackwardStorageType(
const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs);
std::vector<std::string> ListForwardInputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListInputNames(nnvm::Symbol::kAll);
}
std::vector<std::string> ListForwardOutputNames() const {
nnvm::Symbol sym = GetForwardSym();
return sym.ListOutputNames();
}
nnvm::Symbol GetForwardSym() const {
nnvm::Symbol sym;
sym.outputs = fwd_graph_.outputs;
return sym;
}

private:
struct GraphInfo;
12 changes: 6 additions & 6 deletions src/operator/operator_common.h
@@ -221,7 +221,7 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \
{ \
-if (!shape_assign(&(shape_array)[index], TShape(shape))) { \
if (!::mxnet::op::shape_assign(&(shape_array)[index], TShape(shape))) { \
std::ostringstream os; \
os << "Shape inconsistent, Provided = " << (shape_array)[index] << ','\
<< " inferred shape=" << shape; \
@@ -238,11 +238,11 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
*/
#define TYPE_ASSIGN_CHECK(type_array, index, type) \
{ \
-if (!type_assign(&(type_array)[index], type)) { \
if (!::mxnet::op::type_assign(&(type_array)[index], type)) { \
std::ostringstream os; \
os << "Type inconsistent, Provided = " \
-<< type_string((type_array)[index]) << ',' \
-<< " inferred type = " << type_string(type); \
<< ::mxnet::op::type_string((type_array)[index]) << ',' \
<< " inferred type = " << ::mxnet::op::type_string(type); \
throw ::mxnet::op::InferTypeError(os.str(), index); \
} \
}
@@ -291,8 +291,8 @@ inline bool dispatch_mode_assign(DispatchMode *y, const DispatchMode& x) {
#define UNIFORM_TYPE_CHECK(type, expected, arg) \
{ \
CHECK_EQ(type, expected) << "This layer requires uniform type. " \
<< "Expected '" << type_string(expected) \
<< "' v.s. given '" << type_string(type) \
<< "Expected '" << ::mxnet::op::type_string(expected) \
<< "' v.s. given '" << ::mxnet::op::type_string(type) \
<< "' at '" << arg << "'"; \
}
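The only change in this file is qualifying the helpers with ::mxnet::op::. That matters because these macros are now expanded outside namespace mxnet::op (for example from src/imperative/cached_op.cc), where unqualified names would not resolve. A self-contained illustration of the failure mode (toy namespaces, not the MXNet code):

namespace lib {
inline bool helper(int) { return true; }
}  // namespace lib

// Unqualified: only works when expanded inside (or `using`) namespace lib.
#define CHECK_UNQUALIFIED(x) { if (!helper(x)) {} }
// Fully qualified: works from any namespace, like the macros above.
#define CHECK_QUALIFIED(x)   { if (!::lib::helper(x)) {} }

namespace client {
void f() {
  // CHECK_UNQUALIFIED(0);  // would fail to compile: 'helper' not found here
  CHECK_QUALIFIED(0);       // fine: the absolute path resolves everywhere
}
}  // namespace client

int main() {
  client::f();
  return 0;
}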
