Extend _CachedOp into a regular operator.
zheng-da committed Aug 23, 2018
1 parent 31c5fbc commit 0b5df8b
Showing 4 changed files with 518 additions and 16 deletions.
342 changes: 332 additions & 10 deletions src/imperative/cached_op.cc
@@ -873,7 +873,6 @@ OpStatePtr CachedOp::Forward(
return op_state;
}


void CachedOp::DynamicBackward(
const bool retain_graph,
const OpStatePtr& op_state,
@@ -1066,6 +1065,130 @@ void CachedOp::Backward(
Engine::Get()->set_bulk_size(prev_bulk_size);
}

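// Per-node state for the _CachedOp operator: keeps the CachedOp itself and
// the OpStatePtr returned by Forward() so it can be handed to Backward().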
struct CachedOpActualState {
std::shared_ptr<CachedOp> op;
OpStatePtr forward_state;

explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
this->op = op;
}
};

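// FStatefulComputeEx implementation of the forward pass: sets the imperative
// recording/training flags from the OpContext, runs CachedOp::Forward, and
// copies outputs back if CachedOp replaced the output arrays.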
void CachedOpForward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs(in_bufs.size());
std::vector<NDArray *> out_ptrs(out_bufs.size());
for (size_t i = 0; i < in_ptrs.size(); i++)
in_ptrs[i] = &in_bufs[i];
for (size_t i = 0; i < out_ptrs.size(); i++)
out_ptrs[i] = &out_bufs[i];

// Set is_recording correctly for the imperative executor.
bool orig_is_record;
if (ctx.need_grad)
orig_is_record = Imperative::Get()->set_is_recording(true);
else
orig_is_record = Imperative::Get()->is_recording();
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);
Imperative::Get()->set_is_recording(orig_is_record);
// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}

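// FStatefulComputeEx implementation of the backward pass. The backward inputs
// are ordered as [output grads, forward inputs, forward outputs]; only the
// entries flagged by save_inputs()/save_outputs() are passed to
// CachedOp::Backward, together with the forward state saved above.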
void CachedOpBackward(const OpStatePtr& state_ptr,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
using namespace nnvm;
using namespace imperative;
CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
std::vector<NDArray> in_bufs = inputs;
std::vector<NDArray> out_bufs = outputs;
std::vector<NDArray *> in_ptrs;
std::vector<NDArray *> out_ptrs;
CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
in_ptrs.reserve(s.op->num_backward_inputs());
out_ptrs.reserve(s.op->num_inputs());

const std::vector<bool> &save_inputs = s.op->save_inputs();
const std::vector<bool> &save_outputs = s.op->save_outputs();
size_t bwd_in_dep = s.op->num_inputs();
size_t bwd_out_dep = s.op->num_outputs();
CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;

// Find inputs, outputs and ograds
auto ograds_begin = in_bufs.begin();
auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
auto in_begin = ograds_end;
auto in_end = in_begin + bwd_in_dep;
auto out_begin = in_end;
auto out_end = in_bufs.end();

for (auto it = ograds_begin; it != ograds_end; it++)
in_ptrs.push_back(&(*it));

CHECK_EQ(save_inputs.size(), in_end - in_begin);
CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
for (auto it = in_begin; it != in_end; it++) {
auto i = it - in_begin;
if (save_inputs[i])
in_ptrs.push_back(&(*it));
}
for (auto it = out_begin; it != out_end; it++) {
auto i = it - out_begin;
if (save_outputs[i])
in_ptrs.push_back(&(*it));
}
CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
for (size_t i = 0; i < out_bufs.size(); i++)
out_ptrs.push_back(&out_bufs[i]);
CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
// Set is_training correctly for the imperative executor.
bool orig_is_train;
if (ctx.is_train)
orig_is_train = Imperative::Get()->set_is_training(true);
else
orig_is_train = Imperative::Get()->is_training();
// TODO(zhengda) is it right to use false here?
s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
Imperative::Get()->set_is_training(orig_is_train);

// Clean up what we recorded.
s.forward_state.reset();

// The arrays in out_ptrs may be changed by CachedOp.
// If they are, we need to copy the data back.
for (size_t i = 0; i < out_bufs.size(); i++)
if (!out_bufs[i].IsSame(outputs[i]))
CopyFromTo(out_bufs[i], outputs[i]);
}

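// FCreateOpState: wraps the CachedOp stored in attrs.parsed in a
// CachedOpActualState shared by the forward and backward compute functions.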
OpStatePtr CreateCachedOpState(const NodeAttrs& attrs,
Context ctx,
const std::vector<TShape>& in_shapes,
const std::vector<int>& in_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return OpStatePtr::Create<CachedOpActualState>(op);
}

bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -1142,6 +1265,155 @@ bool CachedOp::BackwardStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

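// Shape inference for _CachedOp: copies the known input/output shapes into the
// cached forward graph, runs exec::InferShape on it, and writes the inferred
// shapes back.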
bool CachedOp::ForwardInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
using namespace exec;
nnvm::Graph g(fwd_graph_);
const auto& idx_g = g.indexed_graph();
CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size());
CHECK_EQ(idx_g.outputs().size(), out_shapes->size());

// TODO(zhengda) we can cache the shape vector.
// Put the input and output shapes into the shape vector.
nnvm::ShapeVector shapes(idx_g.num_node_entries());
const auto &input_nids = idx_g.input_nodes();
CHECK_EQ(input_nids.size(), in_shapes->size());
for (size_t i = 0; i < in_shapes->size(); i++) {
auto eid = idx_g.entry_id(input_nids[i], 0);
shapes[eid] = in_shapes->at(i);
}
CHECK_EQ(g.outputs.size(), out_shapes->size());
for (size_t i = 0; i < out_shapes->size(); i++) {
auto eid = idx_g.entry_id(g.outputs[i]);
shapes[eid] = out_shapes->at(i);
}

// Infer shape of the graph.
g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
g = exec::InferShape(std::move(g));

// Copy the inferred shape back to the input shapes and the output shapes.
shapes = g.GetAttr<nnvm::ShapeVector>("shape");
// assign to in_shapes
for (size_t i = 0; i < in_shapes->size(); ++i) {
const auto eid = idx_g.entry_id(input_nids[i], 0);
SHAPE_ASSIGN_CHECK(*in_shapes, i, shapes[eid]);
}
// assign to out_shapes
for (size_t i = 0; i < g.outputs.size(); ++i) {
const auto eid = idx_g.entry_id(g.outputs[i]);
SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
}
// Check if we have inferred the shapes correctly.
return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}

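// Type inference for _CachedOp: same scheme as ForwardInferShape, but over the
// dtype attribute of the cached forward graph.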
bool CachedOp::ForwardInferType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
nnvm::Graph g(fwd_graph_);
const auto& idx_g = g.indexed_graph();
CHECK_EQ(idx_g.input_nodes().size(), in_types->size());
CHECK_EQ(idx_g.outputs().size(), out_types->size());

// TODO(zhengda) we can cache the dtype vector.
// Put the input and output data types into the dtype vector.
nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
const auto &input_nids = idx_g.input_nodes();
CHECK_EQ(input_nids.size(), in_types->size());
for (size_t i = 0; i < in_types->size(); i++) {
auto eid = idx_g.entry_id(input_nids[i], 0);
types[eid] = in_types->at(i);
}
CHECK_EQ(g.outputs.size(), out_types->size());
for (size_t i = 0; i < out_types->size(); i++) {
auto eid = idx_g.entry_id(g.outputs[i]);
types[eid] = out_types->at(i);
}

// Infer data type of the graph.
g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
g = exec::InferType(std::move(g));

types = g.GetAttr<nnvm::DTypeVector>("dtype");
// assign to in_types
for (size_t i = 0; i < in_types->size(); ++i) {
const auto eid = idx_g.entry_id(input_nids[i], 0);
TYPE_ASSIGN_CHECK(*in_types, i, types[eid]);
}
// assign to out_types
for (size_t i = 0; i < g.outputs.size(); ++i) {
const auto eid = idx_g.entry_id(g.outputs[i]);
TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
}
// Check if we have inferred the dtypes correctly.
return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
}

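// Returns the positions of the mutable inputs (auxiliary states) by walking the
// full input-name list and matching it against the read-only and auxiliary
// name lists of the forward symbol.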
std::vector<uint32_t> CachedOp::MutableInputs() const {
nnvm::Symbol sym = GetForwardSym();
const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll);
const std::vector<std::string> immutable_input_names =
sym.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
const std::vector<std::string> mutable_input_names =
sym.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
CHECK_EQ(immutable_input_names.size() + mutable_input_names.size(), input_names.size());
std::vector<uint32_t> ret;
size_t i1 = 0, i2 = 0;
for (size_t i = 0; i < input_names.size(); ++i) {
if (i1 < immutable_input_names.size() && input_names[i] == immutable_input_names[i1]) {
++i1;
} else {
CHECK(i2 < mutable_input_names.size());
CHECK_EQ(input_names[i], mutable_input_names[i2]);
++i2;
ret.push_back(i);
}
}
return ret;
}

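// Collects the distinct resource types requested by the operators in the
// forward symbol.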
std::vector<ResourceRequest> CachedOp::GetResourceRequest() const {
nnvm::Symbol sym = GetForwardSym();
static auto& fresource = Op::GetAttr<FResourceRequest>("FResourceRequest");
std::set<ResourceRequest::Type> resource_types;
DFSVisit(sym.outputs, [&](const nnvm::NodePtr& node) {
if (!node->is_variable() && fresource.count(node->op())) {
for (ResourceRequest& r : fresource[node->op()](node->attrs)) {
resource_types.insert(r.type);
}
}
});
return std::vector<ResourceRequest>(resource_types.begin(), resource_types.end());
}

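// Attribute parser for _CachedOp: parses the CachedOpConfig from the attribute
// dictionary and, when a subgraph is given as a JSON string, constructs the
// CachedOp from that subgraph and the remaining flags.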
void CachedOpParamParser(nnvm::NodeAttrs* attrs) {
CachedOpConfig param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
if (!param.subgraph.empty()) {
nnvm::Graph g = nnvm::pass::LoadJSON(param.subgraph);
CHECK(!g.outputs.empty());
nnvm::Symbol sym;
sym.outputs = g.outputs;
std::vector<std::pair<std::string, std::string> > flags;
for (auto it = attrs->dict.begin(); it != attrs->dict.end(); it++)
flags.emplace_back(it->first, it->second);
attrs->parsed = CachedOpPtr(new CachedOp(sym, flags));
}
}

NNVM_REGISTER_OP(_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs) {
@@ -1152,19 +1424,63 @@ NNVM_REGISTER_OP(_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->num_outputs();
})
.set_attr<FInferStorageType>("FInferStorageType", [](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr_parser(CachedOpParamParser)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(n->attrs.parsed);
return op->Gradient(n, ograds);
});
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardInputNames();
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ListForwardOutputNames();
})
.set_attr<FCreateOpState>("FCreateOpState", CreateCachedOpState)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shapes,
std::vector<TShape> *out_shapes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardInferShape(attrs, in_shapes, out_shapes);
})
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_types,
std::vector<int> *out_types) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardInferType(attrs, in_types, out_types);
})
.set_attr<FInferStorageType>("FInferStorageType",
[](const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_stypes,
std::vector<int>* out_stypes) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->ForwardStorageType(attrs, dev_mask, dispatch_mode, in_stypes, out_stypes);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpForward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpForward)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->MutableInputs();
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const nnvm::NodeAttrs& attrs) {
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->GetResourceRequest();
})
.set_attr<FExecType>("FExecType",
[](const nnvm::NodeAttrs& attrs) {
return ExecType::kSubgraphExec;
})
.add_argument("data", "NDArray-or-Symbol[]", "input data list");

NNVM_REGISTER_OP(_backward_CachedOp)
.set_num_inputs([](const NodeAttrs& attrs){
@@ -1183,6 +1499,12 @@ NNVM_REGISTER_OP(_backward_CachedOp)
const CachedOpPtr& op = nnvm::get<CachedOpPtr>(attrs.parsed);
return op->BackwardStorageType(attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
})
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CachedOpBackward)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CachedOpBackward)
.set_attr<FExecType>("FExecType",
[](const nnvm::NodeAttrs& attrs) {
return ExecType::kSubgraphExec;
})
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<bool>("TIsBackward", true);

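With the registrations above, _CachedOp behaves like a regular stateful operator that can be instantiated from a serialized subgraph. The following is a minimal sketch, not part of this commit, of how the new attribute parser could be driven directly; the helper name, node name, and the "subgraph" dictionary key are assumptions based on CachedOpConfig.

#include <string>
#include <nnvm/node.h>
#include <nnvm/op.h>

// Hypothetical helper: builds NodeAttrs for _CachedOp from a subgraph that was
// serialized with nnvm::pass::SaveJSON, then runs the operator's attr parser
// (CachedOpParamParser) so that attrs.parsed holds the constructed CachedOp.
nnvm::NodeAttrs MakeCachedOpAttrs(const std::string& subgraph_json) {
  nnvm::NodeAttrs attrs;
  attrs.op = nnvm::Op::Get("_CachedOp");
  attrs.name = "cached_subgraph";          // hypothetical node name
  attrs.dict["subgraph"] = subgraph_json;  // key assumed from CachedOpConfig::subgraph
  if (attrs.op->attr_parser != nullptr) {
    attrs.op->attr_parser(&attrs);
  }
  return attrs;
}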