diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 92d9c2699d63..eff234ac68ff 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -261,6 +261,10 @@ struct RunContext {
    * \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
    */
   void *stream;
+  /*!
+   * \brief indicator of whether this execution is run in bulk mode
+   */
+  bool is_bulk;
   /*!
    * \brief get mshadow stream from Context
    * \return the mshadow stream
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index 516e04bf5e82..8d44d9c15531 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -67,7 +67,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr};
+      ret = RunContext{ctx, nullptr, false};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
@@ -85,7 +85,9 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
         use_counter = counter;
         counter = (counter + 1) % kStreams;
       }
-      ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)};
+      ret = RunContext{ctx,
+                       gpu_streams_.at(ctx.dev_id).at(use_counter),
+                       false};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -103,7 +105,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
   RunContext ret;
   switch (ctx.dev_mask()) {
     case cpu::kDevMask:
-      ret = RunContext{ctx, nullptr};
+      ret = RunContext{ctx, nullptr, false};
       break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
@@ -114,7 +116,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
           gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
         }
       }
-      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id)};
+      ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
       break;
 #else
       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h
index ccfd09d64ea8..8be8ee8f2184 100644
--- a/src/engine/threaded_engine.h
+++ b/src/engine/threaded_engine.h
@@ -498,7 +498,13 @@ class ThreadedEngine : public Engine {
     DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
     SyncFn fn = std::move(bulk_status.fn);
     this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
+        ctx.is_bulk = true;
         fn(ctx);
+        ctx.is_bulk = false;
+        bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
+        if (is_gpu) {
+          ctx.get_stream<gpu>()->Wait();
+        }
         on_complete();
       }, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
       FnProperty::kNormal, 0, "ImperativeBulk");
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index f4047d1bd975..58ec4e65b846 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -583,14 +583,11 @@ void CachedOp::StaticInitExec(
   }
 
   size_t bulk_size = idx.num_nodes();
-  std::unordered_set<uint32_t> excludes;
   if (recording || keep_fwd) {
     bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
-    for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
-    for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
   }
 
-  CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
+  CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
                     state.execs, skip_plus_node, &state.opr_segs);
 }
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 4b0d13167356..98f6c8ff10e5 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -406,7 +406,7 @@ inline void PushFCompute(const FCompute& fn,
       fn(attrs, opctx, input_blobs, tmp_req, output_blobs);
       // post-fcompute fallback, cast to original storage type
       CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
-      if (is_gpu) {
+      if (is_gpu && !rctx.is_bulk) {
        rctx.get_stream<gpu>()->Wait();
       }
     }, ctx, read_vars, write_vars, FnProperty::kNormal,
@@ -928,7 +928,6 @@ inline void CreateEngineOpSeg(
     const size_t start_nid,
     const size_t end_nid,
     const size_t bulk_size,
-    const std::unordered_set<uint32_t>& excludes,
     const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
     const std::vector<bool> skip_plus_node,
     std::vector<EngineOprSeg> *opr_segs) {
@@ -944,13 +943,6 @@ inline void CreateEngineOpSeg(
     // Stop at async nodes and invalid node (due to input/output is not allocated)
     bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
-    for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
-      if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
-    }
-    auto num_outputs = node.source->num_outputs();
-    for (size_t i = 0; i < num_outputs && !stop; ++i) {
-      if (excludes.count(idx.entry_id(nid, i))) stop = true;
-    }
     // Create opr segment for previous nodes.
     if (stop && nid > seg_start) {
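
For readers skimming the patch, the following is a minimal standalone sketch (stand-in types and names, not MXNet's actual classes, and not part of the patch) of the synchronization pattern the changes above introduce: while RunContext::is_bulk is set, individual ops skip their per-op GPU stream wait, and the engine issues a single wait after the whole bulk segment has run.

    // Illustrative sketch only; RunContext, SyncFn, RunOp and RunBulk here are
    // simplified stand-ins for the MXNet types touched by this patch.
    #include <functional>
    #include <iostream>
    #include <vector>

    struct RunContext {                // stand-in for mxnet::RunContext
      bool is_bulk = false;            // field added by this patch
      void WaitStream() const {        // stand-in for get_stream<gpu>()->Wait()
        std::cout << "stream sync\n";
      }
    };

    using SyncFn = std::function<void(RunContext&)>;

    // Per-op path (cf. PushFCompute): sync only when not inside a bulk segment.
    void RunOp(RunContext& rctx, const SyncFn& fn, bool is_gpu) {
      fn(rctx);
      if (is_gpu && !rctx.is_bulk) rctx.WaitStream();
    }

    // Bulk path (cf. the ThreadedEngine bulk flush): run the fused ops with
    // is_bulk set, then issue one stream sync for the whole segment.
    void RunBulk(RunContext& rctx, const std::vector<SyncFn>& ops, bool is_gpu) {
      rctx.is_bulk = true;
      for (const auto& op : ops) RunOp(rctx, op, is_gpu);
      rctx.is_bulk = false;
      if (is_gpu) rctx.WaitStream();
    }

    int main() {
      RunContext rctx;
      std::vector<SyncFn> ops(3, [](RunContext&) { std::cout << "op\n"; });
      RunBulk(rctx, ops, /*is_gpu=*/true);  // prints "op" three times, then one "stream sync"
      return 0;
    }

The single end-of-segment wait keeps the guarantee that all GPU work of the segment has completed before the completion callback fires, while removing one synchronization per bulked op.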