This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Re-enable static cached_op optimization #14931

Merged · 1 commit · May 14, 2019
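For background (not part of this diff): the static cached_op optimization is the executor-caching path that Gluon exercises when a block is hybridized with static allocation. Below is a minimal usage sketch, assuming the MXNet 1.x Gluon frontend; the network and shapes are illustrative only.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(10)
net.initialize()

# static_alloc caches and reuses memory for the cached graph; static_shape
# additionally fixes input shapes so executors can be created once and reused.
net.hybridize(static_alloc=True, static_shape=True)

out = net(mx.nd.ones((8, 16)))  # first call builds and caches the static graph
out = net(mx.nd.ones((8, 16)))  # later calls with the same shape reuse the cached executors
```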
8 changes: 6 additions & 2 deletions src/executor/attach_op_execs_pass.cc
@@ -261,7 +261,7 @@ class FComputeExExecutor : public OpExecutor {
   ExecType exec_type_;
 };
 
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i) {
   using nnvm::DTypeVector;
   using mxnet::ShapeVector;
   using nnvm::FMutateInputs;
@@ -302,6 +302,10 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i) {
 
     OpStatePtr state = fcreate_op_state[op](
         inode.source->attrs, vctx[i], ishape, itype);
+    if (p_state) {
+      CHECK_GT(p_state->size(), i);
+      p_state->at(i) = state;
+    }
     FStatefulComputeEx fcompute_ex = common::GetFCompute<FStatefulComputeEx>(
         op, "FStatefulComputeEx", vctx[i]);
     // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
@@ -359,7 +363,7 @@ Graph AttachOpExecs(Graph g) {
   const auto& idx = g.indexed_graph();
   OpExecVector ret(idx.num_nodes());
   for (size_t i = 0; i < idx.num_nodes(); ++i) {
-    CreateOpExecs(g, &ret, i);
+    CreateOpExecs(g, &ret, nullptr, i);
   }
   g.attrs["op_execs"] = std::make_shared<nnvm::any>(ret);
   return g;
9 changes: 8 additions & 1 deletion src/executor/exec_pass.h
@@ -98,6 +98,12 @@ class OpExecutor {
  */
 using OpExecVector = std::vector<std::shared_ptr<OpExecutor> >;
 
+/*!
+ * \brief per node vector of operator states.
+ * \note stored under attribute "op_states"
+ */
+using OpStateVector = std::vector<OpStatePtr>;
+
 /*!
  * \brief per node context vector
  * \note stored under "context"
@@ -115,9 +121,10 @@ using DevMaskVector = std::vector<int>;
  *
  * \param g input graph
  * \param p_ret OpExecVector for input and output
+ * \param p_state OpStateVector to fill, if provided
  * \param i the id of the node
  */
-void CreateOpExecs(const Graph& g, OpExecVector* p_ret, size_t i);
+void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state, size_t i);
 /*!
  * \brief Attach OpExecutor to the graph attributes.
  *
10 changes: 6 additions & 4 deletions src/imperative/cached_op.cc
@@ -285,7 +285,7 @@ bool CachedOp::CheckDynamicShapeExists(const Context& default_ctx,
   CheckAndInferShape(&g, std::move(shape_inputs), true,
                      {0, 0}, {0, 0},
                      &contain_dynamic_shape);
-  if (erase_result) {
+  if (contain_dynamic_shape && erase_result) {
     g.attrs.erase("shape");
     g.attrs.erase("shape_inputs");
   }
@@ -603,7 +603,7 @@ void CachedOp::StaticInitExec(
     }
   } else {
     for (size_t i = start_nid; i < end_nid; ++i) {
-      exec::CreateOpExecs(g, &state.execs, i);
+      exec::CreateOpExecs(g, &state.execs, &state.op_states, i);
     }
     exec::AttachOpResources(g, state.execs, start_nid, end_nid);
 
@@ -705,8 +705,10 @@ void CachedOp::StaticRunOps(
         arg_shapes.emplace_back(ndinput->shape());
         arg_dtypes.emplace_back(ndinput->dtype());
       }
-      state.op_states[i] = createop[node.source->op()](
-          node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+      if (!config_.static_shape) {
Review comment from @ZhennanQin (Contributor, Author), May 12, 2019, on the line above:

@eric-haibin-lin This is the only line changed to fix this PR. We should re-create the op state if static_shape is false.

+        state.op_states[i] =
+            createop[node.source->op()](node.source->attrs, default_ctx, arg_shapes, arg_dtypes);
+      }
       Imperative::Get()->InvokeOp(
           default_ctx, node.source->attrs, ndinputs, ndoutputs, req,
           dispatch_mode, state.op_states[i]);
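To illustrate the review comment above (a usage sketch under assumed behavior, not code from this PR): with static_alloc=True but static_shape=False, input shapes may change between calls, which is why StaticRunOps re-creates the per-node op state on each run instead of reusing the one created in StaticInitExec.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(10)
net.initialize()

# Only static allocation is requested; shapes may vary per call, so stateful
# operators need their state re-created for each new shape.
net.hybridize(static_alloc=True, static_shape=False)

out_small = net(mx.nd.ones((4, 16)))   # first shape
out_large = net(mx.nd.ones((32, 16)))  # different batch size on the static_alloc path
```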
26 changes: 12 additions & 14 deletions src/imperative/imperative_utils.h
@@ -595,23 +595,21 @@ inline bool CheckAndInferShape(nnvm::Graph* p_g, mxnet::ShapeVector&& shapes,
     *contain_unknown = false;
   }
   nnvm::Graph& g = *p_g;
-  if (use_inputs) {
-    if (g.attrs.count("shape_inputs") &&
-        g.GetAttr<mxnet::ShapeVector>("shape_inputs") == shapes) return true;
-  } else if (g.attrs.count("shape")) {
+  if (g.attrs.count("shape")) {
     const auto& prev_shapes = g.GetAttr<mxnet::ShapeVector>("shape");
-    CHECK_EQ(prev_shapes.size(), shapes.size());
-    bool match = true;
-    for (size_t i = 0; i < shapes.size(); ++i) {
-      if (i == entry_range.first) {
-        i = entry_range.second;
-        if (i >= shapes.size()) break;
+    if (prev_shapes.size() == shapes.size()) {
+      bool match = true;
+      for (size_t i = 0; i < shapes.size(); ++i) {
+        if (i == entry_range.first) {
+          i = entry_range.second;
+          if (i >= shapes.size()) break;
+        }
+        if (shapes[i] == prev_shapes[i]) continue;
+        match = false;
+        break;
       }
-      if (shapes[i] == prev_shapes[i]) continue;
-      match = false;
-      break;
+      if (match) return true;
     }
-    if (match) return true;
   }
   g.attrs.erase("shape");
   g.attrs.erase("shape_inputs");
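For clarity (a sketch, not code from the PR): the rewritten hunk replaces the hard CHECK_EQ on the cached shape count with a soft size check, while keeping the element-wise comparison that skips the entry range. Assuming entry_range is a half-open [first, second) interval of indices to ignore, the comparison amounts to:

```python
def cached_shapes_match(prev_shapes, shapes, entry_range):
    """Return True if cached shapes equal the new shapes outside entry_range."""
    if len(prev_shapes) != len(shapes):
        return False  # stale cache: fall through and re-infer instead of asserting
    first, second = entry_range
    i = 0
    while i < len(shapes):
        if i == first:
            i = second  # jump over the ignored entries
            if i >= len(shapes):
                break
        if shapes[i] != prev_shapes[i]:
            return False
        i += 1
    return True
```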