Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Improve cached_op performance for static mode #14785

Merged
merged 4 commits into from
Apr 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/executor/attach_op_execs_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
ExecType exec_type_;
};

void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
using nnvm::DTypeVector;
using mxnet::ShapeVector;
using nnvm::FMutateInputs;
Expand Down Expand Up @@ -302,6 +302,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {

OpStatePtr state = fcreate_op_state[op](
inode.source->attrs, vctx[i], ishape, itype);
if (p_state) {
CHECK_GT(p_state->size(), i);
p_state->at(i) = state;
}
FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
op, "FStatefulComputeEx", vctx[i]);
// FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
Expand Down Expand Up @@ -359,7 +363,7 @@ Graph AttachOpExecs(Graph g) {
const auto& idx = g.indexed_graph();
OpExecVector ret(idx.num_nodes());
for (size_t i = 0; i < idx.num_nodes(); ++i) {
CreateOpExecs(g, &ret, i);
CreateOpExecs(g, &ret, nullptr, i);
}
g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
return g;
Expand Down
9 changes: 8 additions & 1 deletion src/executor/exec_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ class OpExecutor {
*/
using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;

/*!
* \brief per node vector of operator states.
*  Entries are indexed by node id: when CreateOpExecs is given a non-null
*  OpStateVector it stores the operator state it creates for node i at index i.
* \note stored under attribute "op_states"
*/
using OpStateVector = std::vector<OpStatePtr>;

/*!
* \brief per node context vector
* \note stored under "context"
Expand All @@ -115,9 +121,10 @@ using DevMaskVector = std::vector<int>;
*
* \param g input graph
* \param p_ret OpExecVector for input and output
* \param p_state optional OpStateVector; if non-null, the operator state created for node i is stored at index i. May be nullptr.
* \param i the id of the node
*/
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
/*!
* \brief Attach OpExecutor to the graph attributes.
*
Expand Down
10 changes: 6 additions & 4 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
CheckAndInferShape(&g, std::move(shape_inputs), true,
{0, 0}, {0, 0},
&contain_dynamic_shape);
if (erase_result) {
if (contain_dynamic_shape && erase_result) {
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
}
Expand Down Expand Up @@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
}
} else {
for (size_t i = start_nid; i < end_nid; ++i) {
exec::CreateOpExecs(g, &state.execs, i);
exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
}
exec::AttachOpResources(g, state.execs, start_nid, end_nid);

Expand Down Expand Up @@ -705,8 +705,10 @@ void CachedOp::StaticRunOps(
arg_shapes.emplace_back(ndinput->shape());
arg_dtypes.emplace_back(ndinput->dtype());
}
state.op_states[i] = createop[node.source->op()](
node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
if (!state.op_states[i]) {
state.op_states[i] =
createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
}
Imperative::Get()->InvokeOp(
default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, state.op_states[i]);
Expand Down
26 changes: 12 additions & 14 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,23 +595,21 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
*contain_unknown = false;
}
nnvm::Graph& g = *p_g;
if (use_inputs) {
if (g.attrs.count("shape_inputs") &&
g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes) return true;
} else if (g.attrs.count("shape")) {
if (g.attrs.count("shape")) {
const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
CHECK_EQ(prev_shapes.size(), shapes.size());
bool match = true;
for (size_t i = 0; i < shapes.size(); ++i) {
if (i == entry_range.first) {
i = entry_range.second;
if (i >= shapes.size()) break;
if (prev_shapes.size() == shapes.size()) {
bool match = true;
for (size_t i = 0; i < shapes.size(); ++i) {
if (i == entry_range.first) {
i = entry_range.second;
if (i >= shapes.size()) break;
}
if (shapes[i] == prev_shapes[i]) continue;
match = false;
break;
}
if (shapes[i] == prev_shapes[i]) continue;
match = false;
break;
if (match) return true;
}
if (match) return true;
}
g.attrs.erase("shape");
g.attrs.erase("shape_inputs");
Expand Down