execute subgraph operators synchronously (apache#16)
* execute subgraph operators correctly.

* add HasSubgraph attr.

* support subgraph op in imperative.

* rewrite running subgraph in main thread.

* execute subgraph op in imperative.

* remove need_grad.
zheng-da authored and reminisce committed Jun 20, 2018
1 parent 7e96475 commit 8e3d4ed
Showing 4 changed files with 28 additions and 6 deletions.
7 changes: 6 additions & 1 deletion include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
    * In current implementation, copy operator is specially handled by executor.
    * This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * A subgraph execution should happen in the main thread, instead of
+   * in the execution engine.
+   */
+  kSubgraphExec,
 };
 
 /*! \brief the dispatch mode of the operator */
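For orientation, here is a minimal sketch (not part of this diff) of how runtime code can look up an operator's declared ExecType through the FExecType attribute defined in this same header. The helper name GetExecTypeOf and the kSync fallback are illustrative assumptions, not MXNet API:

    #include <mxnet/op_attr_types.h>
    #include <nnvm/op.h>

    // Illustrative helper: query the nnvm op registry for an operator's ExecType.
    // Operators that register no FExecType are treated as ordinary synchronous
    // engine ops here.
    mxnet::ExecType GetExecTypeOf(const nnvm::NodeAttrs& attrs) {
      static const auto& fexec_type =
          nnvm::Op::GetAttr<mxnet::FExecType>("FExecType");
      if (attrs.op != nullptr && fexec_type.count(attrs.op)) {
        return fexec_type[attrs.op](attrs);
      }
      return mxnet::ExecType::kSync;
    }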
3 changes: 3 additions & 0 deletions src/executor/graph_executor.cc
@@ -1614,6 +1614,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+      // If the node contains a subgraph, we can't execute it in the engine.
+      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
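The diff only notes that a subgraph node "can't execute in the engine"; presumably this is because a subgraph operator drives the engine itself from inside its Run, so it is invoked synchronously on the calling thread instead of being pushed as an engine operation. A condensed, hypothetical restatement of the dispatch above (the helper name and simplified signature are not the real GraphExecutor interface):

    // Condensed sketch of the branch added to GraphExecutor::RunOps.
    void DispatchOpNode(const OpNode& opnode, bool profiling) {
      if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
        // Run directly on the calling thread; the second argument is is_gpu.
        opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
      } else if (opnode.cached_opr != nullptr) {
        // Ordinary operators are pushed to the threaded engine as before.
        Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
      }
    }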
21 changes: 16 additions & 5 deletions src/imperative/imperative_utils.h
@@ -434,7 +434,8 @@ inline void PushFComputeEx(const FComputeEx& fn,
     }
   };
 
-  if (exec_type == ExecType::kCrossDeviceCopy) {
+  if (exec_type == ExecType::kCrossDeviceCopy
+      || exec_type == ExecType::kSubgraphExec) {
     run(RunContext{ctx, nullptr});
   } else {
     CHECK(exec_type == ExecType::kSync);
@@ -475,12 +476,18 @@ inline void PushOperator(const OpStatePtr& state,
       InvalidateOutputs(outputs, req);
 #endif
       fcompute_ex(state, opctx, inputs, req, outputs);
-      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+      if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+          && rctx.get_stream<gpu>()) {
         rctx.get_stream<gpu>()->Wait();
       }
     };
 
-    if (exec_type == ExecType::kSync) {
+    // For operators with subgraphs, we need to invoke them in the main thread
+    // instead of the threaded engine.
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
           ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -519,12 +526,16 @@ inline void PushOperator(const OpStatePtr& state,
       fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
       // post-fcompute fallback, cast to original storage type, if necessary
       CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-      if (is_gpu && exec_type == ExecType::kSync) {
+      if (is_gpu && exec_type == ExecType::kSync
+          && rctx.get_stream<gpu>()) {
         rctx.get_stream<gpu>()->Wait();
       }
     };
 
-    if (exec_type == ExecType::kSync) {
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
             run(rctx, engine::CallbackOnComplete());
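Both imperative paths above build RunContext{ctx, nullptr}, i.e. a run context with no attached stream, which is why the GPU synchronization now checks rctx.get_stream<gpu>() before waiting. A small sketch of that guard in isolation (the helper name is made up; assumes mxnet/base.h):

    #include <mxnet/base.h>

    // Illustrative guard: only synchronize when the RunContext actually carries
    // a GPU stream. Ops invoked on the main thread with RunContext{ctx, nullptr}
    // have no stream, so the wait is skipped.
    inline void WaitIfGpuStream(const mxnet::RunContext& rctx) {
      mshadow::Stream<mxnet::gpu>* s = rctx.get_stream<mxnet::gpu>();
      if (s != nullptr) {
        s->Wait();
      }
    }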
3 changes: 3 additions & 0 deletions src/operator/subgraph/subgraph_op.cc
@@ -297,6 +297,9 @@ NNVM_REGISTER_OP(_subgraph_op)
 .set_attr<FInferStorageType>("FInferStorageType", SubgraphOpStorageType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SubgraphOpForward)
 .set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+  return ExecType::kSubgraphExec;
+})
 .add_argument("data", "NDArray-or-Symbol[]", "input data list");
 
 } // namespace op
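Any other operator that needs the same main-thread treatment can opt in the same way. A hypothetical registration, assuming the same includes and namespace as subgraph_op.cc (the operator name is made up; only the FExecType attribute mirrors the change above):

    // Hypothetical operator opting into main-thread (kSubgraphExec) execution.
    NNVM_REGISTER_OP(_my_subgraph_like_op)
    .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
      // Run this operator synchronously on the calling thread, not in the engine.
      return ExecType::kSubgraphExec;
    });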
