diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index f4694efad297..ebe82491eb6e 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -98,7 +98,12 @@ enum class ExecType {
    * In current implementation, copy operator is specially handled by executor.
    * This flag is used for special case treatment and future extension of different copy ops.
    */
-  kCrossDeviceCopy
+  kCrossDeviceCopy,
+  /*!
+   * A subgraph execution should happen in the main thread, instead of
+   * in the execution engine.
+   */
+  kSubgraphExec,
 };
 
 /*! \brief the dispatch mode of the operator */
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 831b5f900237..ae05fe478e68 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1614,6 +1614,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
       CHECK_EQ(opnode.exec->in_array.size(), 1U);
       CHECK_EQ(opnode.exec->out_array.size(), 1U);
       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
+    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
+      // If the node contains a subgraph, we can't execute it in the engine.
+      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
     } else if (opnode.cached_opr != nullptr) {
       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 726531d02994..08ea05af0c46 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -434,7 +434,8 @@ inline void PushFComputeEx(const FComputeEx& fn,
       }
     };
 
-  if (exec_type == ExecType::kCrossDeviceCopy) {
+  if (exec_type == ExecType::kCrossDeviceCopy
+      || exec_type == ExecType::kSubgraphExec) {
     run(RunContext{ctx, nullptr});
   } else {
     CHECK(exec_type == ExecType::kSync);
@@ -475,12 +476,18 @@ inline void PushOperator(const OpStatePtr& state,
         InvalidateOutputs(outputs, req);
 #endif
         fcompute_ex(state, opctx, inputs, req, outputs);
-        if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync) {
+        if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
+            && rctx.get_stream<gpu>()) {
           rctx.get_stream<gpu>()->Wait();
         }
       };
 
-    if (exec_type == ExecType::kSync) {
+    // For operators with subgraphs, we need to invoke them in the main thread
+    // instead of the threaded engine.
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) { run(rctx, engine::CallbackOnComplete()); },
           ctx, read_vars, write_vars, FnProperty::kNormal, 0,
@@ -519,12 +526,16 @@ inline void PushOperator(const OpStatePtr& state,
         fcompute(state, opctx, input_blobs, tmp_req, output_blobs);
         // post-fcompute fallback, cast to original storage type, if necessary
         CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-        if (is_gpu && exec_type == ExecType::kSync) {
+        if (is_gpu && exec_type == ExecType::kSync
+            && rctx.get_stream<gpu>()) {
           rctx.get_stream<gpu>()->Wait();
         }
       };
 
-    if (exec_type == ExecType::kSync) {
+    if (exec_type == ExecType::kSubgraphExec) {
+      RunContext rctx{ctx, nullptr};
+      run(rctx, engine::CallbackOnComplete());
+    } else if (exec_type == ExecType::kSync) {
       Engine::Get()->PushSync(
           [=](RunContext rctx) {
             run(rctx, engine::CallbackOnComplete());
diff --git a/src/operator/subgraph/subgraph_op.cc b/src/operator/subgraph/subgraph_op.cc
index 48b8bc5ba400..d30badbbc9d8 100644
--- a/src/operator/subgraph/subgraph_op.cc
+++ b/src/operator/subgraph/subgraph_op.cc
@@ -297,6 +297,9 @@ NNVM_REGISTER_OP(_subgraph_op)
 .set_attr<FInferStorageType>("FInferStorageType", SubgraphOpStorageType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SubgraphOpForward)
 .set_attr<std::string>("key_var_num_args", "num_args")
+.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
+    return ExecType::kSubgraphExec;
+  })
 .add_argument("data", "NDArray-or-Symbol[]", "input data list");
 
 }  // namespace op
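
Not part of the patch: the last hunk is also the pattern any other subgraph-style operator could follow to opt into main-thread execution. A minimal sketch, assuming a hypothetical stateful operator _my_subgraph_op whose other attributes are already registered:

    NNVM_REGISTER_OP(_my_subgraph_op)
    .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
        // Ask the executor / imperative runtime to invoke this op on the main
        // thread (with a null stream) instead of pushing it to the engine.
        return ExecType::kSubgraphExec;
      });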